ironically none of this cruft gets the loss lower than the original way
parent 4b31f5c808
commit e029a8804d
@@ -795,14 +795,14 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
if type not in _durations_map:
_durations_map[type] = {}
_durations_map[type][f"{key}/{id}"] = duration
_durations_map[type][f"{key}/{id}"] = float(duration)

if not validate:
return True

return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration

return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
return [ f"{key}/{id}" for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []

def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ):
if isinstance(path, str):
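Note: the duration bookkeeping above boils down to caching each entry's length as a float and, when validating, keeping only keys whose duration sits inside the configured bounds. A minimal standalone sketch of that filter follows; filter_keys_by_duration, the durations dict, and the bounds are illustrative stand-ins for _durations_map and cfg.dataset.min_duration / max_duration, not the repo's actual objects.

def filter_keys_by_duration( entries: dict, min_duration: float = 0.0, max_duration: float = 60.0, durations: dict | None = None ) -> list:
	# entries maps "group/id" -> duration in seconds
	if durations is not None:
		# cache every duration as a float, mirroring the float(duration) change above
		durations.update( { key: float( value ) for key, value in entries.items() } )
	# keep only keys whose duration falls inside the configured bounds
	return [ key for key, value in entries.items() if min_duration <= value <= max_duration ]

# example: only "train/a" survives the 1..32 second window
print( filter_keys_by_duration( { "train/a": 3.2, "train/b": 120.0 }, min_duration=1.0, max_duration=32.0 ) )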
@@ -1,6 +1,6 @@
from ..config import cfg

from ..utils.distributed import fix_unset_envs, ddp_model, world_size
from ..utils.distributed import fix_unset_envs, ddp_model, world_size, global_rank
fix_unset_envs()

if cfg.trainer.backend == "deepspeed":
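Note: global_rank is newly imported here so the wandb run below can be named per process. Its exact implementation lives in ..utils.distributed; as a rough assumption, such a helper usually looks like the sketch below (torch.distributed when a process group exists, otherwise the RANK env var torchrun exports).

import os
import torch.distributed as dist

def global_rank() -> int:
	# prefer the initialized process group when one exists
	if dist.is_available() and dist.is_initialized():
		return dist.get_rank()
	# otherwise fall back to the environment torchrun / deepspeed sets up
	return int( os.environ.get( "RANK", 0 ) )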
@@ -396,8 +396,11 @@ def load_engines(training=True, **model_kwargs):
if cfg.lora is not None:
key_name = cfg.lora.full_name

kwargs['name'] = 'job'
if world_size() > 1:
kwargs["group"] = "DDP"
kwargs['name'] = f'job-{global_rank()}'


engine.wandb = wandb.init(project=key_name, **kwargs)
engine.wandb.watch(engine.module)
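Note: the wandb block above groups every rank's run under one "DDP" group and suffixes each run name with its rank, so a multi-GPU job shows up as one collated experiment instead of several anonymous runs. A standalone sketch of the same pattern (project name and the rank / world_size arguments are placeholders):

import wandb

def init_wandb_run( project: str, rank: int, world_size: int ):
	kwargs = { "name": "job" }
	if world_size > 1:
		# all ranks join the same group so the dashboard collates them
		kwargs["group"] = "DDP"
		# per-rank name keeps each process's metrics distinguishable
		kwargs["name"] = f"job-{rank}"
	return wandb.init( project=project, **kwargs )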
@@ -1246,8 +1246,12 @@ def example_usage():

bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )

#available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])
available_tasks = ["tts-nar"]
available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else [])

if cfg.model.experimental.masking_train_p == 0:
available_tasks = ["tts-ar"]
elif cfg.model.experimental.masking_train_p == 1:
available_tasks = ["tts-nar"]

model = AR_NAR(**kwargs).to(cfg.device)
steps = 500 // batch_size
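Note: the task selection above re-enables sampling both paths when the model supports them, but pins the task when masking_train_p sits at an extreme (0 means the masked/NAR objective is never trained, 1 means it always is). A small sketch of that decision with the config values passed in directly (the function name is illustrative):

def pick_available_tasks( capabilities: list, masking_train_p: float ) -> list:
	tasks = []
	if "ar" in capabilities:
		tasks.append( "tts-ar" )
	if "len" in capabilities:
		tasks.append( "tts-nar" )
	# masking probability at either extreme forces a single path
	if masking_train_p == 0:
		tasks = ["tts-ar"]
	elif masking_train_p == 1:
		tasks = ["tts-nar"]
	return tasks

print( pick_available_tasks( ["ar", "len"], 0.5 ) )  # ['tts-ar', 'tts-nar']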
@@ -237,7 +237,7 @@ class AudioEmbedding(nn.Module):
offset = self.names.index( name )
offset -= quant_level # offset by quant level since it'll iterate up that many levels

if self.sums and quant_level > 0:
if sums and quant_level > 0:
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
else:
k = quant_level
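Note: switching from self.sums to the sums argument just makes the summing behaviour overridable per call. When enabled, the embedding sums one table per residual quantization level below the current one; otherwise only the current level is embedded. A toy illustration of the two branches (sizes and the offset are made up for the example, not the module's real configuration):

import torch
import torch.nn as nn

n_levels, vocab, dim = 8, 1024, 16
embeddings = nn.ModuleList( [ nn.Embedding( vocab, dim ) for _ in range( n_levels ) ] )

xi = torch.randint( 0, vocab, ( 75, n_levels ) )  # [timesteps, levels] codes
quant_level, offset, sums = 3, 0, True

if sums and quant_level > 0:
	# sum the embeddings of every level below the current one
	x = sum( embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) )
else:
	# embed only the current level
	x = embeddings[quant_level + offset]( xi[:, quant_level] )

print( x.shape )  # torch.Size([75, 16])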
@@ -379,17 +379,13 @@ class ParallelDecoder(nn.Module):
))

modules.append(module)
"""
downs.append(nn.Linear(d_model, hidden_size, bias=False))
ups.append(nn.Linear(hidden_size, vocab_size, bias=False))
"""

self.levels = levels
self.decoders = nn.ModuleList(modules)
"""
self.downs = nn.ModuleList(downs)
self.ups = nn.ModuleList(ups)
"""

def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
# split into levels
@@ -402,6 +398,7 @@ class ParallelDecoder(nn.Module):
# do one level

# attention + feedforward
"""
x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
# this really hates an output head, so just treat the final output as one
x = x[..., :self.vocab_size]
@@ -413,7 +410,6 @@ class ParallelDecoder(nn.Module):
x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
# upscale to vocab logits
x = self.ups[level]( x )
"""

return x
"""
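Note: the ParallelDecoder hunks above comment out the per-level down/up projections and just slice the decoder's hidden state to vocab_size, per the commit message that none of it lowered the loss. For reference, the disabled path is the usual bottleneck pattern; a generic sketch of its shape (plain Linear / TransformerEncoderLayer stand-ins, not the repo's actual decoder stack):

import torch
import torch.nn as nn

d_model, hidden_size, vocab_size, levels = 1024, 256, 1025, 8

downs = nn.ModuleList( [ nn.Linear( d_model, hidden_size, bias=False ) for _ in range( levels ) ] )
decoders = nn.ModuleList( [ nn.TransformerEncoderLayer( hidden_size, nhead=4, batch_first=True ) for _ in range( levels ) ] )
ups = nn.ModuleList( [ nn.Linear( hidden_size, vocab_size, bias=False ) for _ in range( levels ) ] )

x = torch.randn( 1, 50, d_model )
level = 2
h = downs[level]( x )      # narrow the model width before the per-level decoder
h = decoders[level]( h )   # per-level attention + feedforward
logits = ups[level]( h )   # upscale back to vocab logits
print( logits.shape )      # torch.Size([1, 50, 1025])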
@@ -1281,13 +1277,20 @@ class Base(nn.Module):
)

if not self.parallel_decoding:
"""
# provides only one
return self.proms_emb(
input,
quant_level = 0 if input.dim() == 1 else input.shape[-1],
input if input.dim() == 1 else input[:, quant_level],
quant_level = 0, # if input.dim() == 1 else input.shape[-1],
offset = 0,
)
"""
# sums all
return self.proms_emb(
input,
quant_level = quant_level if input.dim() == 1 else input.shape[-1],
offset = 0,
)
"""
"""

return self.proms_emb( input )
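Note: after commenting out both the single-level and explicit sum-all variants, the surviving path hands the prompt straight to proms_emb and lets it handle however many quant levels the input carries. A toy illustration of the sum-over-levels behaviour that implies (shapes, sizes, and the helper itself are made up for the example):

import torch
import torch.nn as nn

n_levels, vocab, dim = 8, 1024, 16
embs = nn.ModuleList( [ nn.Embedding( vocab, dim ) for _ in range( n_levels ) ] )

def proms_emb( input: torch.Tensor ) -> torch.Tensor:
	# 1D input: a single quant level; 2D input: [timesteps, levels], summed over levels
	if input.dim() == 1:
		return embs[0]( input )
	return sum( embs[k]( input[:, k] ) for k in range( input.shape[-1] ) )

print( proms_emb( torch.randint( 0, vocab, ( 75, n_levels ) ) ).shape )  # torch.Size([75, 16])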