preparing for SpeechX extensions
This commit is contained in:
parent ced31fd9b7
commit 2a71486cb6
@@ -137,6 +137,8 @@ class Model:
     name: str = ""
     size: str = "full"
     resp_levels: int = 1
+    prom_levels: int = 8
+    tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
     arch_type: str = "transformer"

     @property
@@ -157,7 +159,7 @@ class Model:
         if self.arch_type != "transformer":
             name.append(self.arch_type.replace("/", "-"))

-        name.append(f'{cfg.models.levels}')
+        name.append(f'{cfg.models.prom_levels}')

         return "-".join(name)
@@ -192,8 +194,8 @@ class Model:
 @dataclass()
 class Models:
     _models: list[Model] = field(default_factory=lambda: [
-        Model(name="ar", resp_levels=1),
-        Model(name="nar", resp_levels=7),
+        Model(name="ar", resp_levels=1, prom_levels=8, tasks=1),
+        Model(name="nar", resp_levels=7, prom_levels=8, tasks=1),
     ])

     def get(self, name=None):
@@ -215,11 +217,19 @@ class Models:
         return self.get("nar")

     @property
     def levels(self):
         return self.prom_levels

-    prom_levels: int = 8
+    @property
+    def prom_levels(self):
+        prom_levels = 1
+        for model in self._models:
+            prom_levels = max(prom_levels, model.prom_levels)
+        return prom_levels
+
+    @property
+    def tasks(self):
+        tasks = 1
+        for model in self._models:
+            tasks = max(tasks, model.tasks)
+        return tasks

 @dataclass()
 class Hyperparameters:
     batch_size: int = 8
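Note: with the hunk above, cfg.models.prom_levels and cfg.models.tasks resolve to the maximum across all configured models, so the dataloader and embeddings size themselves to the widest model. A minimal usage sketch, assuming the default _models list above:

    # hypothetical usage; Models is the dataclass from this hunk
    models = Models()          # defaults: ar/nar with prom_levels=8, tasks=1
    print(models.prom_levels)  # 8 -> max(model.prom_levels for model in _models)
    print(models.tasks)        # 1 -> max(model.tasks for model in _models)
    print(models.levels)       # 8 -> alias for prom_levels
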
@@ -246,11 +256,9 @@ class Evaluation:
 class DeepSpeed:
     zero_optimization_level: int = 0
     use_compression_training: bool = False
+    compression_bits: int = 8

     def get_ds_cfg(self, model):
-        weights = [ name[0] for name in model.named_parameters() ]
-        bits = 8
-
         scheduler_params = {}
         for k in cfg.hyperparameters.scheduler_params:
             scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
@@ -298,30 +306,17 @@ class DeepSpeed:
             "different_groups": {
                 "wq1": {
                     "params": {
-                        "start_bits": bits,
-                        "target_bits": bits,
+                        "start_bits": self.compression_bits,
+                        "target_bits": self.compression_bits,
                         "quantization_period": 0
                     },
-                    "modules": weights
+                    "modules": [
+                        "blocks",
+                        "retnet",
+                    ]
                 }
             }
         },
-        "activation_quantization": {
-            "shared_parameters":{
-                "enabled": True,
-                "quantization_type": "symmetric",
-                "range_calibration": "dynamic",
-                "schedule_offset": 0
-            },
-            "different_groups": {
-                "aq1": {
-                    "params": {
-                        "bits": bits
-                    },
-                    "modules": weights
-                }
-            }
-        }
     } if self.use_compression_training else None,
     "zero_optimization": {
         "stage": self.zero_optimization_level,
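Note: the effect of this hunk, schematically, is that weight quantization now targets module-name substrings ("blocks" for the transformer path, "retnet" for the RetNet path) instead of enumerating every named parameter, and the activation-quantization group is dropped entirely. A sketch of the resulting group, assuming compression_bits = 8 (the full dict is assembled inside get_ds_cfg):

    weight_quantization_groups = {
        "wq1": {
            "params": {
                "start_bits": 8,   # self.compression_bits
                "target_bits": 8,  # self.compression_bits
                "quantization_period": 0,
            },
            # matched against module names, rather than a list of every weight
            "modules": ["blocks", "retnet"],
        }
    }
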
@@ -467,6 +462,8 @@ try:

     # cached_property stopped working...
     if cfg.dataset.use_hdf5:
+        if cfg.distributed:
+            cfg.dataset.hdf5_flag = "r"
         try:
             cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
         except Exception as e:
@@ -51,7 +51,7 @@ def _get_phone_path(path):

 def _load_quants(path) -> Tensor:
     path = _get_quant_path(path)
-    return torch.load(path)[0][:cfg.models.levels, :].t().to(torch.int16)
+    return torch.load(path)[0][:cfg.models.prom_levels, :].t().to(torch.int16)

 @cache
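Note: for reference, the shape contract this implies for a saved quant file, assuming torch.load yields a (1, n_levels, n_frames) tensor as the [0] index suggests:

    import torch

    # hypothetical stand-in for a saved quant file: 8 codebook levels, 100 frames
    saved = torch.randint(0, 1024, (1, 8, 100))
    prom_levels = 8  # cfg.models.prom_levels
    quants = saved[0][:prom_levels, :].t().to(torch.int16)
    print(quants.shape)  # torch.Size([100, 8]) -> (n_frames, prom_levels)
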
@@ -118,7 +118,6 @@ class Dataset(_Dataset):
         max_duration=cfg.dataset.duration_range[1],
         training=False,
         extra_paths_by_spkr_name: dict[str, list] = {},
-        sample_type=cfg.dataset.sample_type # path | speaker
     ):
         super().__init__()
         self._head = None
@@ -126,7 +125,7 @@ class Dataset(_Dataset):
         self.max_phones = max_phones
         self.min_duration = min_duration
         self.max_duration = max_duration
-        self.sample_type = sample_type
+        self.sampler = None

         if cfg.dataset.validate:
             self.paths = [
@@ -149,6 +148,9 @@ class Dataset(_Dataset):
                 p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
             ]

+        if cfg.dataset.sample_type == "path":
+            self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]
+
         if len(self.paths) == 0 and training:
             raise ValueError("No valid path is found for training.")
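Note: since interleaved_reorder_ is removed below and the reorder now happens once at construction, here is a round-robin sketch of what _interleaved_reorder does, inferred from its usage (group paths by speaker, then take one path per speaker in turn); the repo's own helper may differ in ordering details:

    from collections import defaultdict
    from itertools import chain, zip_longest

    def interleaved_reorder(items, key):
        # bucket items by key (here: speaker name)...
        groups = defaultdict(list)
        for item in items:
            groups[key(item)].append(item)
        # ...then emit one item per bucket per round
        interleaved = chain.from_iterable(zip_longest(*groups.values()))
        return [item for item in interleaved if item is not None]
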
@@ -227,7 +229,7 @@ class Dataset(_Dataset):
         if cfg.dataset.use_hdf5:
             key = _get_hdf5_path(path)
             #qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
-            qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.levels]).to(torch.int16)
+            qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
         else:
             qnt = _load_quants(path)
@@ -255,7 +257,7 @@ class Dataset(_Dataset):
         return ["tts"] # "ns", "sr", "tse", "cse", "nse"

     def __getitem__(self, index):
-        if hasattr(self, "sample_type") and self.sample_type == "speaker":
+        if cfg.dataset.sample_type == "speaker":
             spkr_name = self.spkrs[index]
             spkr_id = self.spkr_symmap[spkr_name]
             path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
@@ -267,7 +269,7 @@ class Dataset(_Dataset):
         if cfg.dataset.use_hdf5:
             key = _get_hdf5_path(path)
             text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
-            resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.levels]).to(torch.int16)
+            resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
         else:
             text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
             resps = _load_quants(path)
@@ -321,11 +323,8 @@ class Dataset(_Dataset):
     def training_(self, value):
         self.training = value

-    def interleaved_reorder_(self, fn):
-        self.paths = [*_interleaved_reorder(self.paths, fn)]
-
     def __len__(self):
-        if hasattr(self, "sample_type") and self.sample_type == "speaker":
+        if cfg.dataset.sample_type == "speaker":
             return min(len(self.spkrs), self._head or len(self.spkrs))
         return min(len(self.paths), self._head or len(self.paths))
@@ -472,7 +471,6 @@ def create_datasets():
         #extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
     )

-    val_dataset.interleaved_reorder_(cfg.get_spkr)
     val_dataset.head_(cfg.evaluation.size)

     return train_dataset, val_dataset
@@ -480,12 +478,10 @@ def create_datasets():

 def create_train_val_dataloader():
     train_dataset, val_dataset = create_datasets()
-    train_dataset.sample_type = cfg.dataset.sample_type #"speaker"

     subtrain_dataset = copy.deepcopy(train_dataset)
-    if subtrain_dataset.sample_type == "path":
+    if cfg.dataset.sample_type == "path":
         subtrain_dataset.head_(cfg.evaluation.size)
-        subtrain_dataset.interleaved_reorder_(cfg.get_spkr)

     train_dl = _create_dataloader(train_dataset, training=True)
     val_dl = _create_dataloader(val_dataset, training=False)
@@ -32,6 +32,10 @@ class AR(Base):
     def n_prom_levels(self) -> int:
         return cfg.models.prom_levels

+    @property
+    def n_tasks(self) -> int:
+        return cfg.models.tasks
+
     @property
     def resp_loss_only(self):
         return False
@@ -113,6 +113,10 @@ class Base(nn.Module):
     @property
     def n_prom_levels(self) -> int:
         raise NotImplementedError

+    @property
+    def n_tasks(self) -> int:
+        raise NotImplementedError
+
     @property
     def resp_loss_only(self):
@@ -120,7 +124,7 @@ class Base(nn.Module):

     def __init__(
         self,
-        n_tokens: int,
+        n_tokens: int = 1024,
         d_model: int = 512,
         n_heads: int = 8,
         n_layers: int = 12,
@@ -132,16 +136,12 @@ class Base(nn.Module):
         self.n_heads = n_heads
         self.n_layers = n_layers

-        causal = self.causal
-
-        # +1 to include the stop token
-        n_stop_tokens = 1 if self.use_stop_token else 0
-        n_resp_tokens = n_tokens + n_stop_tokens
+        n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
+        n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) # AR requires a stop token to... know when to stop

         self.text_emb = Embedding(n_tokens, d_model)

-        # Here I simply use all prom levels
-        self.proms_emb = MultiEmbedding(self.n_prom_levels, n_tokens, d_model)
+        self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
         self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)

         self.sep = nn.Parameter(torch.randn(d_model))
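Note: worked numbers for the two vocabulary sizes above, using the defaults n_tokens = 1024 and tasks = 8 from the config changes. tts is the implicit default task, so only tasks - 1 extra prompt tokens are needed:

    n_tokens = 1024   # audio codebook size
    n_tasks = 8       # cfg.models.tasks

    n_prom_tokens = n_tokens + (n_tasks - 1)  # 1031; ids 1024..1030 flag non-tts tasks
    n_resp_tokens = n_tokens + 1              # 1025; the extra id is the AR stop token
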
@@ -152,14 +152,13 @@ class Base(nn.Module):
                 d_model=d_model,
                 n_heads=n_heads,
                 p_dropout=p_dropout,
-                causal=causal,
+                causal=self.causal,
                 norm_type=self.norm_type,
                 n_levels=self.n_resp_levels,
-                #tention="retention" if self.use_retnet else "attention"
             ) for _ in range(n_layers) ])

         elif self.arch_type == "retnet":
-            self.retnet_config = RetNetConfig(
+            self.retnet = RetNetDecoder(RetNetConfig(
                 vocab_size=n_tokens,
                 decoder_embed_dim=d_model,
                 decoder_retention_heads=n_heads,
@@ -169,13 +168,10 @@ class Base(nn.Module):
                 checkpoint_activations=True,

                 chunkwise_recurrent=self.causal,
-                recurrent_chunkwise_size=128,
+                recurrent_chunkwise_size=64,
                 no_output_layer=True,
                 decoder_normalize_before=True,
-            )
-            self.retnet = RetNetDecoder(
-                self.retnet_config
-            )
+            ))

         self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -281,13 +277,14 @@ class Base(nn.Module):
         )

         x, m = list_to_tensor(x_list)

+        device = x.device
+
         if self.arch_type == "transformer":
             x = self.sin_emb.add_pe(x)
             for block in self.blocks:
                 x = block(x, m, quant_levels)
         elif self.arch_type == "retnet":
-            # to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
             x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
+            state = self.retnet.get_incremental_state( state, 'prev_state' )
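Note: the retnet branch now threads torchscale-style incremental state back out of the forward pass, which is what makes recurrent (token-by-token) AR inference possible. A hypothetical driver loop, assuming a forward that accepts and returns state as in the hunks above (argument names are illustrative, not the repo's actual signature):

    # hypothetical sketch; `model` is an AR whose forward returns (ret, state)
    state = {}   # incremental_state dict consumed/updated by RetNetDecoder
    resps = []
    for _ in range(max_steps):  # max_steps: assumed generation bound
        token, state = model(text, proms, resps, state=state)
        resps.append(token)
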
@@ -301,7 +298,7 @@ class Base(nn.Module):
         if any([l == 0 for l in map(len, targ_list)]):
             raise ValueError("Cannot compute loss given empty targ_list.")

-        ignore_sep = torch.tensor(self.ignore_index, device=x.device)
+        ignore_sep = torch.tensor(self.ignore_index, device=device)

         # ignore the prompt when computing loss
         prom_list = [
@@ -348,11 +345,6 @@ class Base(nn.Module):
             acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
             precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
         )
-        del targ_list
-        del prom_list
-        del text_prom_list
-        del y_list

         # return the entire generated token string
         if return_all:
@@ -366,9 +358,6 @@ class Base(nn.Module):
         else:
             logits = torch.stack([hi[-1] for hi in h_list])
             ret = Categorical(logits=logits / sampling_temperature).sample()

-        del x_list
-        del h_list
-
         return ret, state
@@ -31,6 +31,10 @@ class NAR(Base):
     def n_prom_levels(self) -> int:
         return cfg.models.prom_levels

+    @property
+    def n_tasks(self) -> int:
+        return cfg.models.tasks
+
     @property
     def resp_loss_only(self):
         return True
@@ -87,8 +87,8 @@ def run_eval(engines, eval_name, dl):

     # pseudo loss calculation since we don't get the logits during eval
     min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
-    ref_audio = ref_audio[..., 0:min_length]
-    hyp_audio = hyp_audio[..., 0:min_length]
+    ref_audio = center_crop(ref_audio, min_length) #ref_audio[..., 0:min_length]
+    hyp_audio = center_crop(hyp_audio, min_length) #hyp_audio[..., 0:min_length]
     try:
         stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
     except Exception as e:
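Note: center_crop replaces head-cropping so the compared windows come from the middle of each clip rather than always the start. A minimal sketch of such a helper, assuming it trims the last (time) dimension symmetrically; the repo's actual import may differ:

    def center_crop(audio, length):
        # drop an equal amount from both ends of the time axis
        start = (audio.shape[-1] - length) // 2
        return audio[..., start:start + length]
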
@@ -141,7 +141,7 @@ def run_eval(engines, eval_name, dl):

     iteration = engines.global_step
     engines_stats['it'] = iteration
-    engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
+    #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

     _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
vall_e/utils/sampler.py (new file, +2)

@@ -0,0 +1,2 @@
+class Sampler():
+    ...
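Note: the new Sampler is only a stub in this commit. Given the self.sampler = None slot added to Dataset above, one plausible shape for it is a pool sampler that draws without replacement and reshuffles once drained (purely speculative; nothing beyond the stub exists yet):

    import random

    class PoolSampler:
        def __init__(self, pool):
            self.pool = list(pool)
            self.current = []

        def sample(self):
            if not self.current:
                # refill and shuffle once the pool is exhausted
                self.current = random.sample(self.pool, len(self.pool))
            return self.current.pop()
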
@@ -81,9 +81,21 @@ def load_engines():
         if cfg.trainer.load_state_dict:
             load_path = cfg.ckpt_dir / name / "fp32.pth"
             state = torch.load(load_path)
             # exporting the model from the zero_to_fp32.py exports the actual module's dict
             # exporting with vall_e.export exports the state dict under .module
             if "module" in state:
                 state = state["module"]
-            model.load_state_dict(state)
+
+            print(model.proms_emb.weight.shape, state['proms_emb.weight'].shape)
+
+            # extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
+            if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
+                n_prom_levels, n_prom_tokens, d_model = state['proms_emb.weight'].shape
+
+                model.proms_emb.weight.data[:n_prom_levels, :n_prom_tokens, :] = state['proms_emb.weight'].data[:n_prom_levels, :n_prom_tokens, :]
+                state['proms_emb.weight'] = model.proms_emb.weight
+
+            model.load_state_dict(state, strict=cfg.trainer.strict_loading)

         engines[name] = Engine(
             model=model,
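Note: the checkpoint surgery above in isolation: when n_prom_levels or n_prom_tokens grows (here, from adding task tokens), the old embedding rows are copied into the freshly initialized, larger weight, so trained entries survive and only the new task ids start random. A standalone sketch with the worked sizes from earlier:

    import torch

    old_weight = torch.randn(8, 1024, 512)  # checkpoint: (n_prom_levels, n_prom_tokens, d_model)
    new_weight = torch.randn(8, 1031, 512)  # current model: 1024 + (8 - 1) task tokens

    levels, tokens, _ = old_weight.shape
    new_weight[:levels, :tokens, :] = old_weight  # preserve trained embeddings
    # new_weight is then written back into the state dict before load_state_dict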