From 2a71486cb65e3e95a543644257961d4c33612375 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 18 Aug 2023 20:58:07 -0500
Subject: [PATCH] preparing for SpeechX extensions

---
 vall_e/config.py        | 55 +++++++++++++++++++----------------------
 vall_e/data.py          | 24 ++++++++----------
 vall_e/models/ar.py     |  4 +++
 vall_e/models/base.py   | 41 +++++++++++-------------------
 vall_e/models/nar.py    |  4 +++
 vall_e/train.py         |  6 ++---
 vall_e/utils/sampler.py |  2 ++
 vall_e/utils/trainer.py | 14 ++++++++++-
 8 files changed, 77 insertions(+), 73 deletions(-)
 create mode 100644 vall_e/utils/sampler.py

diff --git a/vall_e/config.py b/vall_e/config.py
index 1f5b4e7..9318393 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -137,6 +137,8 @@ class Model:
 	name: str = ""
 	size: str = "full"
 	resp_levels: int = 1
+	prom_levels: int = 8
+	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
 	arch_type: str = "transformer"

 	@property
@@ -157,7 +159,7 @@ class Model:
 		if self.arch_type != "transformer":
 			name.append(self.arch_type.replace("/", "-"))

-		name.append(f'{cfg.models.levels}')
+		name.append(f'{cfg.models.prom_levels}')

 		return "-".join(name)

@@ -192,8 +194,8 @@ class Model:
 @dataclass()
 class Models:
 	_models: list[Model] = field(default_factory=lambda: [
-		Model(name="ar", resp_levels=1),
-		Model(name="nar", resp_levels=7),
+		Model(name="ar", resp_levels=1, prom_levels=8, tasks=1),
+		Model(name="nar", resp_levels=7, prom_levels=8, tasks=1),
 	])

 	def get(self, name=None):
@@ -215,11 +217,19 @@ class Models:
 		return self.get("nar")

 	@property
-	def levels(self):
-		return self.prom_levels
-
-	prom_levels: int = 8
+	def prom_levels(self):
+		prom_levels = 1
+		for model in self._models:
+			prom_levels = max(prom_levels, model.prom_levels)
+		return prom_levels
+
+	@property
+	def tasks(self):
+		tasks = 1
+		for model in self._models:
+			tasks = max(tasks, model.tasks)
+		return tasks

 @dataclass()
 class Hyperparameters:
 	batch_size: int = 8
@@ -246,11 +256,9 @@ class Evaluation:
 class DeepSpeed:
 	zero_optimization_level: int = 0
 	use_compression_training: bool = False
+	compression_bits: int = 8

 	def get_ds_cfg(self, model):
-		weights = [ name[0] for name in model.named_parameters() ]
-		bits = 8
-
 		scheduler_params = {}
 		for k in cfg.hyperparameters.scheduler_params:
 			scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
@@ -298,30 +306,17 @@ class DeepSpeed:
 					"different_groups": {
 						"wq1": {
 							"params": {
-								"start_bits": bits,
-								"target_bits": bits,
+								"start_bits": self.compression_bits,
+								"target_bits": self.compression_bits,
 								"quantization_period": 0
 							},
-							"modules": weights
+							"modules": [
+								"blocks",
+								"retnet",
+							]
 						}
 					}
 				},
-				"activation_quantization": {
-					"shared_parameters":{
-						"enabled": True,
-						"quantization_type": "symmetric",
-						"range_calibration": "dynamic",
-						"schedule_offset": 0
-					},
-					"different_groups": {
-						"aq1": {
-							"params": {
-								"bits": bits
-							},
-							"modules": weights
-						}
-					}
-				}
 			} if self.use_compression_training else None,
 			"zero_optimization": {
 				"stage": self.zero_optimization_level,
@@ -467,6 +462,8 @@ try: # cached_property stopped working...
 	if cfg.dataset.use_hdf5:
+		if cfg.distributed:
+			cfg.dataset.hdf5_flag = "r"
 		try:
 			cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', cfg.dataset.hdf5_flag) # to-do: have an easy-to-set flag that determines if we're training or creating the dataset
 		except Exception as e:
diff --git a/vall_e/data.py b/vall_e/data.py
index 3d7d30f..9bb507c 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -51,7 +51,7 @@ def _get_phone_path(path):

 def _load_quants(path) -> Tensor:
 	path = _get_quant_path(path)
-	return torch.load(path)[0][:cfg.models.levels, :].t().to(torch.int16)
+	return torch.load(path)[0][:cfg.models.prom_levels, :].t().to(torch.int16)


 @cache
@@ -118,7 +118,6 @@ class Dataset(_Dataset):
 		max_duration=cfg.dataset.duration_range[1],
 		training=False,
 		extra_paths_by_spkr_name: dict[str, list] = {},
-		sample_type=cfg.dataset.sample_type # path | speaker
 	):
 		super().__init__()
 		self._head = None
@@ -126,7 +125,7 @@ class Dataset(_Dataset):
 		self.max_phones = max_phones
 		self.min_duration = min_duration
 		self.max_duration = max_duration
-		self.sample_type = sample_type
+		self.sampler = None

 		if cfg.dataset.validate:
 			self.paths = [
@@ -149,6 +148,9 @@ class Dataset(_Dataset):
 				p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
 			]

+		if cfg.dataset.sample_type == "path":
+			self.paths = [*_interleaved_reorder(self.paths, cfg.get_spkr)]
+
 		if len(self.paths) == 0 and training:
 			raise ValueError("No valid path is found for training.")

@@ -227,7 +229,7 @@ class Dataset(_Dataset):
 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)
 			#qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
-			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.levels]).to(torch.int16)
+			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
 		else:
 			qnt = _load_quants(path)

@@ -255,7 +257,7 @@ class Dataset(_Dataset):
 		return ["tts"] # "ns", "sr", "tse", "cse", "nse"

 	def __getitem__(self, index):
-		if hasattr(self, "sample_type") and self.sample_type == "speaker":
+		if cfg.dataset.sample_type == "speaker":
 			spkr_name = self.spkrs[index]
 			spkr_id = self.spkr_symmap[spkr_name]
 			path = random.choice([*set(self.paths_by_spkr_name[spkr_name])])
@@ -267,7 +269,7 @@ class Dataset(_Dataset):
 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)
 			text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
-			resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.levels]).to(torch.int16)
+			resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
 		else:
 			text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
 			resps = _load_quants(path)
@@ -321,11 +323,8 @@ class Dataset(_Dataset):
 	def training_(self, value):
 		self.training = value

-	def interleaved_reorder_(self, fn):
-		self.paths = [*_interleaved_reorder(self.paths, fn)]
-
 	def __len__(self):
-		if hasattr(self, "sample_type") and self.sample_type == "speaker":
+		if cfg.dataset.sample_type == "speaker":
 			return min(len(self.spkrs), self._head or len(self.spkrs))
 		return min(len(self.paths), self._head or len(self.paths))

@@ -472,7 +471,6 @@ def create_datasets():
 		#extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
 	)

-	val_dataset.interleaved_reorder_(cfg.get_spkr)
 	val_dataset.head_(cfg.evaluation.size)

 	return train_dataset, val_dataset

@@ -480,12 +478,10 @@ def create_datasets():
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()

-	train_dataset.sample_type = cfg.dataset.sample_type #"speaker"

 	subtrain_dataset = copy.deepcopy(train_dataset)
-	if subtrain_dataset.sample_type == "path":
+	if cfg.dataset.sample_type == "path":
 		subtrain_dataset.head_(cfg.evaluation.size)
-		subtrain_dataset.interleaved_reorder_(cfg.get_spkr)

 	train_dl = _create_dataloader(train_dataset, training=True)
 	val_dl = _create_dataloader(val_dataset, training=False)
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 7ef8158..696e9c8 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -32,6 +32,10 @@ class AR(Base):
 	def n_prom_levels(self) -> int:
 		return cfg.models.prom_levels

+	@property
+	def n_tasks(self) -> int:
+		return cfg.models.tasks
+
 	@property
 	def resp_loss_only(self):
 		return False
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 71fd3c2..beed6f8 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -113,6 +113,10 @@ class Base(nn.Module):
 	@property
 	def n_prom_levels(self) -> int:
 		raise NotImplementedError
+
+	@property
+	def n_tasks(self) -> int:
+		raise NotImplementedError

 	@property
 	def resp_loss_only(self):
@@ -120,7 +124,7 @@ class Base(nn.Module):
 	def __init__(
 		self,
-		n_tokens: int,
+		n_tokens: int = 1024,
 		d_model: int = 512,
 		n_heads: int = 8,
 		n_layers: int = 12,
@@ -132,16 +136,12 @@ class Base(nn.Module):
 		self.n_heads = n_heads
 		self.n_layers = n_layers

-		causal = self.causal
-
-		# +1 to include the stop token
-		n_stop_tokens = 1 if self.use_stop_token else 0
-		n_resp_tokens = n_tokens + n_stop_tokens
+		n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
+		n_resp_tokens = n_tokens + (1 if self.use_stop_token else 0) # AR requires a stop token to... know when to stop

 		self.text_emb = Embedding(n_tokens, d_model)
-
-		# Here I simply use all prom levels
-		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_tokens, d_model)
+		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 		self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)

 		self.sep = nn.Parameter(torch.randn(d_model))
@@ -152,14 +152,13 @@ class Base(nn.Module):
 				d_model=d_model,
 				n_heads=n_heads,
 				p_dropout=p_dropout,
-				causal=causal,
+				causal=self.causal,
 				norm_type=self.norm_type,
 				n_levels=self.n_resp_levels,
-				#tention="retention" if self.use_retnet else "attention"
 			) for _ in range(n_layers) ])
 		elif self.arch_type == "retnet":
-			self.retnet_config = RetNetConfig(
+			self.retnet = RetNetDecoder(RetNetConfig(
 				vocab_size=n_tokens,
 				decoder_embed_dim=d_model,
 				decoder_retention_heads=n_heads,
@@ -169,13 +168,10 @@ class Base(nn.Module):
 				checkpoint_activations=True,
 				chunkwise_recurrent=self.causal,
-				recurrent_chunkwise_size=128,
+				recurrent_chunkwise_size=64,
 				no_output_layer=True,
 				decoder_normalize_before=True,
-			)
-			self.retnet = RetNetDecoder(
-				self.retnet_config
-			)
+			))

 		self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -281,13 +277,14 @@ class Base(nn.Module):
 		)

 		x, m = list_to_tensor(x_list)
-
+		device = x.device

 		if self.arch_type == "transformer":
 			x = self.sin_emb.add_pe(x)
 			for block in self.blocks:
 				x = block(x, m, quant_levels)
 		elif self.arch_type == "retnet":
+			# to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
 			x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
 			state = self.retnet.get_incremental_state( state, 'prev_state' )
@@ -301,7 +298,7 @@ class Base(nn.Module):
 			if any([l == 0 for l in map(len, targ_list)]):
 				raise ValueError("Cannot compute loss given empty targ_list.")

-			ignore_sep = torch.tensor(self.ignore_index, device=x.device)
+			ignore_sep = torch.tensor(self.ignore_index, device=device)

 			# ignore the prompt when computing loss
 			prom_list = [
@@ -348,11 +345,6 @@ class Base(nn.Module):
 				acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
 				precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
 			)

-			del targ_list
-			del prom_list
-			del text_prom_list
-			del y_list
-
 		# return the entire generated token string
 		if return_all:
@@ -366,9 +358,6 @@ class Base(nn.Module):
 		else:
 			logits = torch.stack([hi[-1] for hi in h_list])
 			ret = Categorical(logits=logits / sampling_temperature).sample()
-
-		del x_list
-		del h_list

 		return ret, state
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 57dc966..757017c 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -31,6 +31,10 @@ class NAR(Base):
 	def n_prom_levels(self) -> int:
 		return cfg.models.prom_levels

+	@property
+	def n_tasks(self) -> int:
+		return cfg.models.tasks
+
 	@property
 	def resp_loss_only(self):
 		return True
diff --git a/vall_e/train.py b/vall_e/train.py
index a53f58a..3a67ea3 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -87,8 +87,8 @@ def run_eval(engines, eval_name, dl):
 			# pseudo loss calculation since we don't get the logits during eval
 			min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
-			ref_audio = ref_audio[..., 0:min_length]
-			hyp_audio = hyp_audio[..., 0:min_length]
+			ref_audio = center_crop(ref_audio, min_length) #ref_audio[..., 0:min_length]
+			hyp_audio = center_crop(hyp_audio, min_length) #hyp_audio[..., 0:min_length]
 			try:
 				stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
 			except Exception as e:
@@ -141,7 +141,7 @@ def run_eval(engines, eval_name, dl):
 	iteration = engines.global_step
 	engines_stats['it'] = iteration
-	engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
+	#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

 	_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
diff --git a/vall_e/utils/sampler.py b/vall_e/utils/sampler.py
new file mode 100644
index 0000000..ad22dec
--- /dev/null
+++ b/vall_e/utils/sampler.py
@@ -0,0 +1,2 @@
+class Sampler():
+	...
\ No newline at end of file
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index ca7af73..2dc849e 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -81,9 +81,21 @@ def load_engines():
 		if cfg.trainer.load_state_dict:
 			load_path = cfg.ckpt_dir / name / "fp32.pth"
 			state = torch.load(load_path)
+			# exporting the model from zero_to_fp32.py exports the actual module's dict
+			# exporting with vall_e.export exports the state dict under .module
 			if "module" in state:
 				state = state["module"]
-			model.load_state_dict(state)
+
+			print(model.proms_emb.weight.shape, state['proms_emb.weight'].shape)
+
+			# extend the proms_emb if we ever touch n_prom_levels or n_prom_tokens (from adding tasks)
+			if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
+				n_prom_levels, n_prom_tokens, d_model = state['proms_emb.weight'].shape
+
+				model.proms_emb.weight.data[:n_prom_levels, :n_prom_tokens, :] = state['proms_emb.weight'].data[:n_prom_levels, :n_prom_tokens, :]
+				state['proms_emb.weight'] = model.proms_emb.weight
+
+			model.load_state_dict(state, strict=cfg.trainer.strict_loading)

 		engines[name] = Engine(
 			model=model,
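
For concreteness, the vocabulary sizing the patched base.py derives, as a minimal sketch assuming the defaults from config.py above (n_tokens=1024, tasks=8):

    # assumed defaults, taken from config.py in this patch
    n_tokens = 1024                            # codec codebook entries per quantizer level
    n_tasks = 8                                # "tts" plus seven more task slots
    n_prom_tokens = n_tokens + (n_tasks - 1)   # 1031: tts is the inherent task, so it gets no task token
    n_resp_tokens = n_tokens + 1               # 1025 for the AR's stop token (the NAR stays at 1024)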
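
The proms_emb widening in load_engines is the subtle part of this patch, so here is a minimal, self-contained sketch of the same copy; this MultiEmbedding is a hypothetical stand-in that only mirrors the real module's [n_levels, n_tokens, d_model] weight shape:

    import torch
    import torch.nn as nn

    class MultiEmbedding(nn.Module):
        # stand-in: just the [n_levels, n_tokens, d_model] weight the trainer slices into
        def __init__(self, n_levels, n_tokens, d_model):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(n_levels, n_tokens, d_model))

    old = MultiEmbedding(8, 1024, 512)      # shape stored in the checkpoint
    new = MultiEmbedding(8, 1024 + 7, 512)  # current model, widened by 7 task tokens
    state = { 'proms_emb.weight': old.weight.detach().clone() }

    # the same check-and-copy the patch performs: pour the old table into the corner
    # of the new one, leaving the freshly added token rows at their random init
    if new.weight.shape[0] > state['proms_emb.weight'].shape[0] or new.weight.shape[1] > state['proms_emb.weight'].shape[1]:
        n_levels, n_tokens, d_model = state['proms_emb.weight'].shape
        new.weight.data[:n_levels, :n_tokens, :] = state['proms_emb.weight'].data
        state['proms_emb.weight'] = new.weight

After the in-place copy, the state dict entry aliases the already-filled new weight, so the load_state_dict that follows leaves it intact, and strict=cfg.trainer.strict_loading tolerates any remaining key mismatches.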