From e029a8804d1e18acebbb17052e1ea9146d3022cd Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 12 Feb 2025 11:17:00 -0600 Subject: [PATCH] ironically none of this cruft gets the loss lower than the original way --- vall_e/data.py | 4 ++-- vall_e/engines/__init__.py | 5 ++++- vall_e/models/ar_nar.py | 8 ++++++-- vall_e/models/base.py | 23 +++++++++++++---------- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index fcfb439..407eadb 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -795,14 +795,14 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ): if type not in _durations_map: _durations_map[type] = {} - _durations_map[type][f"{key}/{id}"] = duration + _durations_map[type][f"{key}/{id}"] = float(duration) if not validate: return True return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration - return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else [] + return [ f"{key}/{id}" for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else [] def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ): if isinstance(path, str): diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 4ab1315..2ceb836 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -1,6 +1,6 @@ from ..config import cfg -from ..utils.distributed import fix_unset_envs, ddp_model, world_size +from ..utils.distributed import fix_unset_envs, ddp_model, world_size, global_rank fix_unset_envs() if cfg.trainer.backend == "deepspeed": @@ -396,8 +396,11 @@ def load_engines(training=True, **model_kwargs): if cfg.lora is not None: key_name = cfg.lora.full_name + kwargs['name'] = 'job' if world_size() > 1: kwargs["group"] = "DDP" + kwargs['name'] = f'job-{global_rank()}' + engine.wandb = wandb.init(project=key_name, **kwargs) engine.wandb.watch(engine.module) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index cac48f6..654cfde 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -1246,8 +1246,12 @@ def example_usage(): bos_id, space_id, eos_id = cfg.tokenizer.encode( " " ) - #available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else []) - available_tasks = ["tts-nar"] + available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else []) + + if cfg.model.experimental.masking_train_p == 0: + available_tasks = ["tts-ar"] + elif cfg.model.experimental.masking_train_p == 1: + available_tasks = ["tts-nar"] model = AR_NAR(**kwargs).to(cfg.device) steps = 500 // batch_size diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 53c746b..88c68ed 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -237,7 +237,7 @@ class AudioEmbedding(nn.Module): offset = self.names.index( name ) offset -= quant_level # offset by quant level since it'll iterate up that many levels - if self.sums and quant_level > 0: + if sums and quant_level > 0: x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] ) else: k = quant_level @@ -379,17 +379,13 @@ class ParallelDecoder(nn.Module): )) modules.append(module) - """ downs.append(nn.Linear(d_model, hidden_size, bias=False)) ups.append(nn.Linear(hidden_size, vocab_size, bias=False)) - """ self.levels = levels self.decoders = nn.ModuleList(modules) - """ self.downs = nn.ModuleList(downs) self.ups = nn.ModuleList(ups) - """ def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor: # split into levels @@ -402,6 +398,7 @@ class ParallelDecoder(nn.Module): # do one level # attention + feedforward + """ x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"] # this really hates an output head, so just treat the final output as one x = x[..., :self.vocab_size] @@ -413,7 +410,6 @@ class ParallelDecoder(nn.Module): x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"] # upscale to vocab logits x = self.ups[level]( x ) - """ return x """ @@ -1281,13 +1277,20 @@ class Base(nn.Module): ) if not self.parallel_decoding: + """ + # provides only one return self.proms_emb( - input, - quant_level = 0 if input.dim() == 1 else input.shape[-1], + input if input.dim() == 1 else input[:, quant_level], + quant_level = 0, # if input.dim() == 1 else input.shape[-1], + offset = 0, + ) + """ + # sums all + return self.proms_emb( + input, + quant_level = quant_level if input.dim() == 1 else input.shape[-1], offset = 0, ) - """ - """ return self.proms_emb( input )