with the amount of tweaks I keep making I could have probably had the nvidia/audio-codec-44khz model realized already......
This commit is contained in:
parent e3becec0e8
commit a65c8144f4
@@ -414,6 +414,13 @@ class Model:
 				return 16
 			return 12
 
+	@property
+	def ffn(self):
+		if isinstance(self.size, dict) and hasattr(self.size, "ffn"):
+			return self.size['ffn']
+
+		return 4
+
 	@property
 	def activation_checkpointing(self):
 		return cfg.trainer.activation_checkpointing
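A standalone check of the new property's behavior may help here: hasattr() probes attributes rather than dict keys, so with a plain dict the first branch never fires and the default is returned. A minimal sketch, with names copied from the hunk above:

class Model:
	size = {"ffn": 8}

	@property
	def ffn(self):
		# hasattr() checks attributes, not dict keys, so a plain dict
		# always falls through to the default below.
		if isinstance(self.size, dict) and hasattr(self.size, "ffn"):
			return self.size['ffn']
		return 4

print(Model().ffn)  # prints 4; an attribute-style container would be needed to return 8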
@@ -396,11 +396,12 @@ def load_engines(training=True, **model_kwargs):
 			if cfg.lora is not None:
 				key_name = cfg.lora.full_name
 
-			kwargs['id'] = 'job'
+			salt = "run"
+			kwargs['id'] = f'{key_name}-{salt}'
 			kwargs['resume'] = 'allow'
 			if world_size() > 1:
 				kwargs["group"] = "DDP"
-				kwargs['id'] = f'job-{global_rank()}'
+				kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
 
 
 			engine.wandb = wandb.init(project=key_name, **kwargs)
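For reference, a sketch of how the new wandb run ids compose; key_name and the rank value here are illustrative stand-ins for cfg.lora.full_name and global_rank():

key_name = "lora-example"  # stand-in for cfg.lora.full_name
salt = "run"
global_rank = 1            # stand-in for global_rank() under DDP

kwargs = { 'id': f'{key_name}-{salt}', 'resume': 'allow' }
# with DDP, each rank gets its own suffixed id:
kwargs['id'] = f'{key_name}-{salt}-{global_rank}'
print(kwargs['id'])  # lora-example-run-1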
@@ -37,6 +37,7 @@ import time
 import torch
 import torch.distributed
 import os
+import re
 
 from torch import Tensor
 from torch.distributed import all_reduce
@@ -597,6 +598,22 @@ class Engines(dict[str, Engine]):
 			if engine.wandb is not None:
 				engine.wandb.log(model_stats, step=engine.global_step)
 
+			filtered_keys = [ k for k in model_stats.keys() if "[" in k ]
+			filtered_values = {}
+			for k in filtered_keys:
+				v = model_stats[k]
+				del model_stats[k]
+
+				nk = re.sub(r"\[\d+\]", "", k)
+
+				if nk not in filtered_values:
+					filtered_values[nk] = []
+
+				filtered_values[nk].append( v )
+
+			for k, v in filtered_values.items():
+				model_stats[k] = sum(v) / len(v)
+
 			model_stats = model_stats | dict(
 				lr=engine.get_lr()[0],
 				elapsed_time=elapsed_time,
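The new filtering pass collapses per-level metric keys (e.g. resp[0].nll, resp[1].nll) into a single averaged key before the merged stats dict is built. A self-contained sketch with made-up numbers:

import re

model_stats = { "resp[0].nll": 4.0, "resp[1].nll": 6.0, "loss.nll": 5.0 }

filtered_keys = [ k for k in model_stats.keys() if "[" in k ]
filtered_values = {}
for k in filtered_keys:
	v = model_stats[k]
	del model_stats[k]
	nk = re.sub(r"\[\d+\]", "", k)  # "resp[0].nll" -> "resp.nll"
	if nk not in filtered_values:
		filtered_values[nk] = []
	filtered_values[nk].append( v )

for k, v in filtered_values.items():
	model_stats[k] = sum(v) / len(v)

print(model_stats)  # {'loss.nll': 5.0, 'resp.nll': 5.0}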
@@ -171,7 +171,7 @@ class AR_NAR(Base):
 					resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
 
 			# only apply stop token for RVQ level 0
-			if quant_level <= 0 and timesteps[i] is None:
+			if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None):
 				# append stop tokens for AR
 				if task not in text_task:
 					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
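A truth-table sketch of the widened condition: pre-v7 models append the stop token only at RVQ level 0, while v7+ appends it at every level of a non-timestep (AR) pass:

def applies_stop_token(version, quant_level, timestep):
	return (version < 7 and quant_level <= 0 and timestep is None) \
		or (version >= 7 and timestep is None)

print(applies_stop_token(6, 1, None))  # False: old models, level 0 only
print(applies_stop_token(7, 1, None))  # True:  v7+, all levels
print(applies_stop_token(7, 0, 0.5))   # False: NAR (timestep) passes never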
@@ -1232,7 +1232,7 @@ def example_usage():
 		'n_text_tokens': cfg.model.text_tokens,
 		'n_audio_tokens': cfg.model.audio_tokens,
 
-		'd_model': 1024, # 256, # 1024, # 1536
+		'd_model': 1536, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
 		'n_layers': 12, # 32
 		'n_experts': 1 if not cfg.model else cfg.model.experts,
@@ -1254,7 +1254,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]
 
 	model = AR_NAR(**kwargs).to(cfg.device)
-	steps = 500 // batch_size
+	steps = 250 // batch_size
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -1444,7 +1444,7 @@ def example_usage():
 	"""
 
 	for task in available_tasks:
-		sample("final", task=task)
+		sample("final", task="tts-nar")
 
 	engines.quit()
@@ -135,7 +135,7 @@ def _interleave_sequence_flatten( input: list[torch.Tensor] ):
 
 # automagically parses a batch-list and returns it as a list
 """
-class Embedding(nn.Embedding):
+class Embedding(ml.Embedding):
	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
 		if len(x_list) == 0:
 			return []
@@ -192,7 +192,7 @@ class AudioEmbedding_Old(nn.Module):
 		# array of embeddings
 		#	proms are [0, resp_levels]
 		#	resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
+		self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
 		self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
 
@@ -223,7 +223,7 @@ class AudioEmbedding(nn.Module):
 		# array of embeddings
 		#	proms are [0, resp_levels]
 		#	resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
+		self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# further experimentation is needed to see if this actually is useful
 		self.sums = sums
 		#
@@ -350,7 +350,7 @@ class AudioEncoder(nn.Module):
 		token_dim: int,
 	):
 		super().__init__()
-		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
 		self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
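For context, AudioEncoder's proj layer implies that per-level embeddings are concatenated feature-wise and then projected back down to one model-width vector; the forward itself isn't shown in this hunk, so the following shape-flow sketch rests on assumed sizes:

import torch
import torch.nn as nn

token_dim, n_levels, t = 512, 8, 10  # assumed sizes for illustration
embs = nn.ModuleList([nn.Embedding(1025, token_dim) for _ in range(n_levels)])
proj = nn.Linear(8 * token_dim, 1 * token_dim)

xi = torch.randint(0, 1025, (t, n_levels))  # [timesteps, RVQ levels]
x = torch.cat([ embs[l](xi[:, l]) for l in range(n_levels) ], dim=-1)
print(proj(x).shape)  # torch.Size([10, 512])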
@@ -538,6 +538,7 @@ class Base(nn.Module):
 		n_raw_text_tokens: int = 8575,
 
 		d_model: int = 512,
+		d_ffn: int = 4,
 		n_heads: int = 8,
 		n_layers: int = 12,
 		p_dropout: float = 0.1,
@@ -735,13 +736,15 @@ class Base(nn.Module):
 		if self.version >= 6:
 			self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
 
+		self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
+		self.monolithic_audio_encoder = False # monolithic sounds bad
 		if self.version >= 7:
 			pd_model = d_model // 4
-			pd_ffn = pd_model * 4
+			pd_ffn = pd_model * d_ffn
 			pd_heads = n_heads // 4
 			pd_layers = 1
 
-			if False:
+			if self.monolithic_audio_encoder:
 				self.audio_emb = AudioEncoder(
 					n_tokens=n_audio_tokens + 1, # masked token
 					n_levels=self.n_resp_levels,
@@ -763,7 +766,7 @@ class Base(nn.Module):
 				self.n_resp_levels,
 				d_model,
 				dict(
-					vocab_size=n_audio_tokens,
+					vocab_size=n_audio_tokens + 1,
 					hidden_size=pd_model,
 					max_position_embeddings=max_position_embeddings,
 					intermediate_size=pd_ffn,
@@ -821,7 +824,7 @@ class Base(nn.Module):
 				vocab_size=n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
-				intermediate_size=d_model*4,
+				intermediate_size=d_model*d_ffn,
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
 				attention_dropout=p_dropout if training else 0.0,
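Taken together with the d_ffn argument added above, the feed-forward width is now configurable instead of a hard-coded 4x multiplier; a sketch of the derived sizes with example values:

d_model, d_ffn, n_heads = 1024, 4, 16   # example values
intermediate_size = d_model * d_ffn     # 4096, identical to the old d_model*4 default

# the parallel-decoder dimensions scale down by 4 and now follow d_ffn as well:
pd_model = d_model // 4                 # 256
pd_ffn = pd_model * d_ffn               # 1024 (was pd_model * 4)
pd_heads = n_heads // 4                 # 4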
@@ -1134,7 +1137,7 @@ class Base(nn.Module):
 			inputs[i].append( ( "resp", resps_list[i] ) )
 
 			if self.version >= 7:
-				classifier_level = f"NAR:{quant_level}:{quant_level}"
+				classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
 
 			inputs[i].append( ("classifier_level", classifier_level) )
 		# Audio length prediction task
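The rewritten f-string folds the AR/NAR distinction into one expression, keyed on whether a timestep is present; a sketch of what it yields:

def classifier_level(quant_level, timestep):
	return f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"

print(classifier_level(0, None))  # AR:0:0  (causal pass)
print(classifier_level(2, 0.5))   # NAR:2:2 (masked/timestep pass)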
@@ -1530,7 +1533,7 @@ class Base(nn.Module):
 				return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
 
 			# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
-			if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
+			if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums):
 				return torch.full_like(input[..., 0], self.ignore_index)
 
 			if self.version < 7:
@@ -1562,7 +1565,7 @@ class Base(nn.Module):
 			else:
 				accuracy_metric = MulticlassAccuracy(
 					logit.shape[-1],
-					top_k = 10,
+					top_k = min(logit.shape[0], 10),
 					average="micro",
 					multidim_average="global",
 					ignore_index = -100
@@ -1610,6 +1613,9 @@ class Base(nn.Module):
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
 					# iterate over the list to inject their tokens
 					token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
+
+					if logits[batch_index].dim() < 3 and token.dim() >= 2:
+						token = token[..., 0]
 				elif name == "resp":
 					# mask found, apply it
 					if self.version < 7:
@@ -1659,9 +1665,24 @@ class Base(nn.Module):
 
 					if loss_factor == 0.0:
 						continue
 
+					# cringe way to deduce "requested" level
+					level = quant_level
+					for i in range( self.n_resp_levels ):
+						if classifier_level == f'NAR:{i}:{i}':
+							level = i
+							break
+
 					if logits[batch_index].dim() < 3:
 						nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
 
 						if name == "resp":
 							name = f'{name}[{quant_level}]'
+					elif not self.resp_parallel_training:
+						if name == "resp":
+							name = f'{name}[{level}]'
+						sequence = token if token.dim() <= 1 else token[:, level]
+						nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
 					else:
 						nlls = []
 						accs = []
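The "cringe" deduction just parses the level back out of a NAR:i:i classifier string, falling back to quant_level when nothing matches; an equivalent sketch:

def deduce_level(classifier_level, quant_level, n_resp_levels=8):
	level = quant_level
	for i in range( n_resp_levels ):
		if classifier_level == f'NAR:{i}:{i}':
			level = i
			break
	return level

print(deduce_level("NAR:3:3", 0))  # 3
print(deduce_level("AR:0:0", 0))   # 0 (falls back to quant_level)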
@@ -1670,15 +1691,29 @@ class Base(nn.Module):
 							sequence = token if token.dim() <= 1 else token[:, level]
 							nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
 
-							if nll:
-								nlls.append( nll )
-							if metrics:
-								accs.append( metrics )
+							if name == "resp":
+								if nll is not None:
+									if f'{name}[{level}].nll' not in loss:
+										loss[f'{name}[{level}].nll'] = []
+									loss[f"{name}[{level}].nll"].append( nll * loss_factor )
+
+								if metrics is not None:
+									if f'{name}[{level}].acc' not in stats:
+										stats[f'{name}[{level}].acc'] = []
+									stats[f"{name}[{level}].acc"].append( metrics )
+
+								nll = None
+								metrics = None
+							else:
+								if nll:
+									nlls.append( nll )
+								if metrics:
+									accs.append( metrics )
 					else:
 						if nlls:
 							nll = sum(nlls) / len(nlls)
 						if accs:
 							accs = sum(accs) / len(accs)
 
 					if nll is not None:
 						if f'{name}.nll' not in loss:
@@ -1698,6 +1733,16 @@ class Base(nn.Module):
 				if logits[batch_index].dim() < 3:
 					sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
 					nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
+				elif not self.resp_parallel_training:
+					# cringe way to deduce "requested" level
+					level = 0
+					for i in range( self.n_resp_levels ):
+						if classifier_level == f'NAR:{i}:{i}':
+							level = i
+							break
+					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
+					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
+					nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
 				else:
 					nlls = []
 					accs = []
@@ -1779,7 +1824,7 @@ class Base(nn.Module):
 		# needs to be done here as we still have our raw inputs
 		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 		classifier_levels = self.get_input( inputs, name="classifier_level" )
-		causal_levels = [ "AR:0:0", "stt", "len", "phn" ]
+		causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
 
 		# right now limit to new versions because I need to retrain the model for noncausal masks...
 		is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
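With the list comprehension, every AR:i:i level is now treated as causal rather than just AR:0:0; for example, with eight response levels:

n_resp_levels = 8  # illustrative
causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( n_resp_levels) ]
print(causal_levels)
# ['stt', 'len', 'phn', 'AR:0:0', 'AR:1:1', ..., 'AR:7:7']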
@@ -1800,7 +1845,7 @@ class Base(nn.Module):
 		if self.version >= 7:
 			logits = [ logit for logit in logits ]
 
-			audio_decoder_levels = [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
+			audio_decoder_levels = [ f"AR:{i}:{i}" for i in range(self.n_resp_levels) ] + [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
 
 			decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
 			classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
@@ -2157,7 +2202,7 @@ if __name__ == "__main__":
 		if is_from_pretrained:
 			n_vocab = end - start
 
-			embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
+			embds[k] = torch.ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
 			embds[k].weight[:] = model.embed_tokens.weight[start:end, :]
 
 		if classifier_idx >= 0:
@@ -2169,7 +2214,7 @@ if __name__ == "__main__":
 			heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
 		else:
 			embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
-			embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
+			embds[k] = torch.ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
 			embds[k].load_state_dict({ "weight": embd_weight })
 
 		if classifier_idx >= 0:
@@ -32,7 +32,7 @@ if cfg.optimizations.bitsandbytes:
 		Linear = bnb.nn.Linear8bitLt
 
 	if cfg.optimizations.embedding:
-		Embedding = bnb.nn.modules.Embedding
+		Embedding = bnb.nn.StableEmbedding
 		"""
 		Embedding.forward = lambda self, input: ( self.norm(F.embedding(
 			input,
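For context, the nn-to-ml migration throughout this commit points call sites at a module-level alias like the one patched here, which resolves to bitsandbytes' StableEmbedding when the optimization is enabled and plain nn.Embedding otherwise. A minimal sketch of the pattern, with the config guard replaced by a stand-in flag:

import torch.nn as nn

Embedding = nn.Embedding  # default
try:
	import bitsandbytes as bnb
	use_bnb_embedding = True  # stand-in for cfg.optimizations.embedding
	if use_bnb_embedding:
		Embedding = bnb.nn.StableEmbedding
except ImportError:
	pass

emb = Embedding(1024, 512)  # call sites are unchanged either way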