ironically none of this cruft gets the loss lower than the original way

This commit is contained in:
mrq 2025-02-12 11:17:00 -06:00
parent 4b31f5c808
commit e029a8804d
4 changed files with 25 additions and 15 deletions

View File

@ -795,14 +795,14 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
if type not in _durations_map:
_durations_map[type] = {}
_durations_map[type][f"{key}/{id}"] = duration
_durations_map[type][f"{key}/{id}"] = float(duration)
if not validate:
return True
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
return [ f"{key}/{id}" for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ):
if isinstance(path, str):

View File

@ -1,6 +1,6 @@
from ..config import cfg
from ..utils.distributed import fix_unset_envs, ddp_model, world_size
from ..utils.distributed import fix_unset_envs, ddp_model, world_size, global_rank
fix_unset_envs()
if cfg.trainer.backend == "deepspeed":
@ -396,8 +396,11 @@ def load_engines(training=True, **model_kwargs):
if cfg.lora is not None:
key_name = cfg.lora.full_name
kwargs['name'] = 'job'
if world_size() > 1:
kwargs["group"] = "DDP"
kwargs['name'] = f'job-{global_rank()}'
engine.wandb = wandb.init(project=key_name, **kwargs)
engine.wandb.watch(engine.module)

View File

@ -1246,8 +1246,12 @@ def example_usage():
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
#available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])
available_tasks = ["tts-nar"]
available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])
if cfg.model.experimental.masking_train_p == 0:
available_tasks = ["tts-ar"]
elif cfg.model.experimental.masking_train_p == 1:
available_tasks = ["tts-nar"]
model = AR_NAR(**kwargs).to(cfg.device)
steps = 500 // batch_size

View File

@ -237,7 +237,7 @@ class AudioEmbedding(nn.Module):
offset = self.names.index( name )
offset -= quant_level # offset by quant level since it'll iterate up that many levels
if self.sums and quant_level > 0:
if sums and quant_level > 0:
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
else:
k = quant_level
@ -379,17 +379,13 @@ class ParallelDecoder(nn.Module):
))
modules.append(module)
"""
downs.append(nn.Linear(d_model, hidden_size, bias=False))
ups.append(nn.Linear(hidden_size, vocab_size, bias=False))
"""
self.levels = levels
self.decoders = nn.ModuleList(modules)
"""
self.downs = nn.ModuleList(downs)
self.ups = nn.ModuleList(ups)
"""
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
# split into levels
@ -402,6 +398,7 @@ class ParallelDecoder(nn.Module):
# do one level
# attention + feedforward
"""
x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
# this really hates an output head, so just treat the final output as one
x = x[..., :self.vocab_size]
@ -413,7 +410,6 @@ class ParallelDecoder(nn.Module):
x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
# upscale to vocab logits
x = self.ups[level]( x )
"""
return x
"""
@ -1281,13 +1277,20 @@ class Base(nn.Module):
)
if not self.parallel_decoding:
"""
# provides only one
return self.proms_emb(
input,
quant_level = 0 if input.dim() == 1 else input.shape[-1],
input if input.dim() == 1 else input[:, quant_level],
quant_level = 0, # if input.dim() == 1 else input.shape[-1],
offset = 0,
)
"""
# sums all
return self.proms_emb(
input,
quant_level = quant_level if input.dim() == 1 else input.shape[-1],
offset = 0,
)
"""
"""
return self.proms_emb( input )