with the number of tweaks I keep making, I could probably have had the nvidia/audio-codec-44khz model realized already...

mrq 2025-02-13 18:38:40 -06:00
parent e3becec0e8
commit a65c8144f4
6 changed files with 100 additions and 30 deletions

View File

@@ -414,6 +414,13 @@ class Model:
return 16
return 12
@property
def ffn(self):
if isinstance(self.size, dict) and hasattr(self.size, "ffn"):
return self.size['ffn']
return 4
@property
def activation_checkpointing(self):
return cfg.trainer.activation_checkpointing
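Rough sketch of what the new ffn knob feeds into, with a stand-in class and made-up values (the real intermediate_size hookup is in models/base.py further down); note that for a dict-style size, a key check is what actually triggers the override, since hasattr() doesn't see dict keys:

class ModelSketch:	# stand-in for the Model config class, values made up
	size = { "ffn": 8 }

	@property
	def ffn( self ):
		# "ffn" in self.size is the check that fires for a dict;
		# hasattr(self.size, "ffn") never matches dict keys
		if isinstance( self.size, dict ) and "ffn" in self.size:
			return self.size["ffn"]
		return 4

d_model = 1536
d_ffn = ModelSketch().ffn
intermediate_size = d_model * d_ffn	# 12288; mirrors intermediate_size=d_model*d_ffn in models/base.py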

View File

@@ -396,11 +396,12 @@ def load_engines(training=True, **model_kwargs):
if cfg.lora is not None:
key_name = cfg.lora.full_name
kwargs['id'] = 'job'
salt = "run"
kwargs['id'] = f'{key_name}-{salt}'
kwargs['resume'] = 'allow'
if world_size() > 1:
kwargs["group"] = "DDP"
kwargs['id'] = f'job-{global_rank()}'
kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
engine.wandb = wandb.init(project=key_name, **kwargs)
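Rough sketch of the run-id scheme above; the name, world size, and rank are stand-ins for cfg.lora.full_name / world_size() / global_rank():

key_name = "some-model-or-lora"	# stand-in
salt = "run"
world_size, rank = 2, 1	# stand-ins
run_id = f"{key_name}-{salt}"
if world_size > 1:
	run_id = f"{key_name}-{salt}-{rank}"	# one resumable run per rank
# wandb.init(project=key_name, id=run_id, resume="allow", group="DDP")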

View File

@@ -37,6 +37,7 @@ import time
import torch
import torch.distributed
import os
import re
from torch import Tensor
from torch.distributed import all_reduce
@@ -597,6 +598,22 @@ class Engines(dict[str, Engine]):
if engine.wandb is not None:
engine.wandb.log(model_stats, step=engine.global_step)
filtered_keys = [ k for k in model_stats.keys() if "[" in k ]
filtered_values = {}
for k in filtered_keys:
v = model_stats[k]
del model_stats[k]
nk = re.sub(r"\[\d+\]", "", k)
if nk not in filtered_values:
filtered_values[nk] = []
filtered_values[nk].append( v )
for k, v in filtered_values.items():
model_stats[k] = sum(v) / len(v)
model_stats = model_stats | dict(
lr=engine.get_lr()[0],
elapsed_time=elapsed_time,
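Rough sketch of the per-level stat collapsing above, with made-up key names:

import re

model_stats = { "loss.nll[0]": 4.0, "loss.nll[1]": 2.0, "loss.nll[7]": 3.0, "lr": 1.0e-4 }
filtered_keys = [ k for k in model_stats.keys() if "[" in k ]
filtered_values = {}
for k in filtered_keys:
	v = model_stats.pop(k)
	nk = re.sub(r"\[\d+\]", "", k)	# "loss.nll[7]" -> "loss.nll"
	filtered_values.setdefault(nk, []).append(v)
for k, v in filtered_values.items():
	model_stats[k] = sum(v) / len(v)
# model_stats -> {"lr": 0.0001, "loss.nll": 3.0}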

View File

@@ -171,7 +171,7 @@ class AR_NAR(Base):
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
# only apply stop token for RVQ level 0
if quant_level <= 0 and timesteps[i] is None:
if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None):
# append stop tokens for AR
if task not in text_task:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
@@ -1232,7 +1232,7 @@ def example_usage():
'n_text_tokens': cfg.model.text_tokens,
'n_audio_tokens': cfg.model.audio_tokens,
'd_model': 1024, # 256, # 1024, # 1536
'd_model': 1536, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
'n_experts': 1 if not cfg.model else cfg.model.experts,
@@ -1254,7 +1254,7 @@ def example_usage():
available_tasks = ["tts-nar"]
model = AR_NAR(**kwargs).to(cfg.device)
steps = 500 // batch_size
steps = 250 // batch_size
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -1444,7 +1444,7 @@ def example_usage():
"""
for task in available_tasks:
sample("final", task=task)
sample("final", task="tts-nar")
engines.quit()

View File

@@ -135,7 +135,7 @@ def _interleave_sequence_flatten( input: list[torch.Tensor] ):
# automagically parses a batch-list and returns it as a list
"""
class Embedding(nn.Embedding):
class Embedding(ml.Embedding):
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
if len(x_list) == 0:
return []
@@ -192,7 +192,7 @@ class AudioEmbedding_Old(nn.Module):
# array of embeddings
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
@@ -223,7 +223,7 @@ class AudioEmbedding(nn.Module):
# array of embeddings
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
# further experimentation is needed to see if this actually is useful
self.sums = sums
#
@@ -350,7 +350,7 @@ class AudioEncoder(nn.Module):
token_dim: int,
):
super().__init__()
self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
@@ -538,6 +538,7 @@ class Base(nn.Module):
n_raw_text_tokens: int = 8575,
d_model: int = 512,
d_ffn: int = 4,
n_heads: int = 8,
n_layers: int = 12,
p_dropout: float = 0.1,
@@ -735,13 +736,15 @@ class Base(nn.Module):
if self.version >= 6:
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
self.monolithic_audio_encoder = False # monolithic sounds bad
if self.version >= 7:
pd_model = d_model // 4
pd_ffn = pd_model * 4
pd_ffn = pd_model * d_ffn
pd_heads = n_heads // 4
pd_layers = 1
if False:
if self.monolithic_audio_encoder:
self.audio_emb = AudioEncoder(
n_tokens=n_audio_tokens + 1, # masked token
n_levels=self.n_resp_levels,
@@ -763,7 +766,7 @@ class Base(nn.Module):
self.n_resp_levels,
d_model,
dict(
vocab_size=n_audio_tokens,
vocab_size=n_audio_tokens + 1,
hidden_size=pd_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=pd_ffn,
@@ -821,7 +824,7 @@ class Base(nn.Module):
vocab_size=n_vocab,
hidden_size=d_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=d_model*4,
intermediate_size=d_model*d_ffn,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
@@ -1134,7 +1137,7 @@ class Base(nn.Module):
inputs[i].append( ( "resp", resps_list[i] ) )
if self.version >= 7:
classifier_level = f"NAR:{quant_level}:{quant_level}"
classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
inputs[i].append( ("classifier_level", classifier_level) )
# Audio length prediction task
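Rough sketch of what the reworked classifier_level tag resolves to (0.5 is just an arbitrary non-None timestep):

quant_level = 0
for timestep in ( None, 0.5 ):
	tag = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
	print( timestep, tag )
# None -> AR:0:0 (causal pass), 0.5 -> NAR:0:0 (masked/demasking pass)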
@@ -1530,7 +1533,7 @@ class Base(nn.Module):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums):
return torch.full_like(input[..., 0], self.ignore_index)
if self.version < 7:
@@ -1562,7 +1565,7 @@ class Base(nn.Module):
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 10,
top_k = min(logit.shape[0], 10),
average="micro",
multidim_average="global",
ignore_index = -100
@@ -1610,6 +1613,9 @@ class Base(nn.Module):
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
if logits[batch_index].dim() < 3 and token.dim() >= 2:
token = token[..., 0]
elif name == "resp":
# mask found, apply it
if self.version < 7:
@@ -1659,9 +1665,24 @@ class Base(nn.Module):
if loss_factor == 0.0:
continue
# cringe way to deduce "requested" level
level = quant_level
for i in range( self.n_resp_levels ):
if classifier_level == f'NAR:{i}:{i}':
level = i
break
if logits[batch_index].dim() < 3:
nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
if name == "resp":
name = f'{name}[{quant_level}]'
elif not self.resp_parallel_training:
if name == "resp":
name = f'{name}[{level}]'
sequence = token if token.dim() <= 1 else token[:, level]
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
else:
nlls = []
accs = []
@@ -1670,15 +1691,29 @@ class Base(nn.Module):
sequence = token if token.dim() <= 1 else token[:, level]
nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
if nll:
nlls.append( nll )
if metrics:
accs.append( metrics )
if name == "resp":
if nll is not None:
if f'{name}[{level}].nll' not in loss:
loss[f'{name}[{level}].nll'] = []
loss[f"{name}[{level}].nll"].append( nll * loss_factor )
if metrics is not None:
if f'{name}[{level}].acc' not in stats:
stats[f'{name}[{level}].acc'] = []
stats[f"{name}[{level}].acc"].append( metrics )
if nlls:
nll = sum(nlls) / len(nlls)
if accs:
accs = sum(accs) / len(accs)
nll = None
metrics = None
else:
if nll:
nlls.append( nll )
if metrics:
accs.append( metrics )
else:
if nlls:
nll = sum(nlls) / len(nlls)
if accs:
accs = sum(accs) / len(accs)
if nll is not None:
if f'{name}.nll' not in loss:
@@ -1698,6 +1733,16 @@ class Base(nn.Module):
if logits[batch_index].dim() < 3:
sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
elif not self.resp_parallel_training:
# cringe way to deduce "requested" level
level = 0
for i in range( self.n_resp_levels ):
if classifier_level == f'NAR:{i}:{i}':
level = i
break
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
else:
nlls = []
accs = []
@@ -1779,7 +1824,7 @@ class Base(nn.Module):
# needs to be done here as we still have our raw inputs
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" )
causal_levels = [ "AR:0:0", "stt", "len", "phn" ]
causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
# right now limit to new versions because I need to retrain the model for noncausal masks...
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
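Rough sketch of the whitelist this now builds; 8 levels is an assumption (the AudioEncoder proj earlier takes 8 * token_dim):

n_resp_levels = 8	# assumed
causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{i}:{i}" for i in range( n_resp_levels ) ]
# ['stt', 'len', 'phn', 'AR:0:0', 'AR:1:1', ..., 'AR:7:7']; NAR:i:i entries keep non-causal masks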
@@ -1800,7 +1845,7 @@ class Base(nn.Module):
if self.version >= 7:
logits = [ logit for logit in logits ]
audio_decoder_levels = [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
audio_decoder_levels = [ f"AR:{i}:{i}" for i in range(self.n_resp_levels) ] + [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
@@ -2157,7 +2202,7 @@ if __name__ == "__main__":
if is_from_pretrained:
n_vocab = end - start
embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
embds[k] = torch.ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
embds[k].weight[:] = model.embed_tokens.weight[start:end, :]
if classifier_idx >= 0:
@@ -2169,7 +2214,7 @@ if __name__ == "__main__":
heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
else:
embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
embds[k] = torch.ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
embds[k].load_state_dict({ "weight": embd_weight })
if classifier_idx >= 0:

View File

@@ -32,7 +32,7 @@ if cfg.optimizations.bitsandbytes:
Linear = bnb.nn.Linear8bitLt
if cfg.optimizations.embedding:
Embedding = bnb.nn.modules.Embedding
Embedding = bnb.nn.StableEmbedding
"""
Embedding.forward = lambda self, input: ( self.norm(F.embedding(
input,