what has science done

mrq 2025-02-13 16:07:40 -06:00
parent cebacc2303
commit 54d65cf37d
3 changed files with 60 additions and 18 deletions

View File

@@ -396,11 +396,12 @@ def load_engines(training=True, **model_kwargs):
         if cfg.lora is not None:
             key_name = cfg.lora.full_name
-        kwargs['id'] = 'job'
+        salt = "run"
+        kwargs['id'] = f'{key_name}-{salt}'
         kwargs['resume'] = 'allow'
         if world_size() > 1:
             kwargs["group"] = "DDP"
-            kwargs['id'] = f'job-{global_rank()}'
+            kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
         engine.wandb = wandb.init(project=key_name, **kwargs)
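
In plain terms, the run id handed to wandb.init is no longer the generic 'job' (or 'job-<rank>' under DDP) but is built from the LoRA's full name plus a fixed salt, with the global rank appended when world_size() > 1, so resumed runs land on the same wandb entry. A minimal sketch of the id scheme (the standalone helper and the example names are mine, not the repo's):

    def make_run_id(key_name, salt="run", rank=None):
        # mirrors the id built above: "<lora full name>-<salt>", plus "-<rank>" under DDP
        run_id = f"{key_name}-{salt}"
        if rank is not None:
            run_id = f"{run_id}-{rank}"
        return run_id

    assert make_run_id("my-lora") == "my-lora-run"
    assert make_run_id("my-lora", rank=1) == "my-lora-run-1"
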

View File

@@ -171,7 +171,7 @@ class AR_NAR(Base):
                 resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
             # only apply stop token for RVQ level 0
-            if quant_level <= 0 and timesteps[i] is None:
+            if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None):
                 # append stop tokens for AR
                 if task not in text_task:
                     resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
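
The new condition is dense inline; spelled out, pre-version-7 models only append the audio stop token at RVQ level 0, while version 7 and up append it for any level, as long as the sample carries no (masked/NAR) timestep. A small restatement of the same predicate, for illustration only:

    def appends_stop_token(version, quant_level, timestep):
        # no stop token for timestep-conditioned (NAR) samples either way
        if timestep is not None:
            return False
        # pre-v7: only RVQ level 0; v7+: every level
        return quant_level <= 0 if version < 7 else True

    assert appends_stop_token(6, 2, None) is False
    assert appends_stop_token(7, 2, None) is True
    assert appends_stop_token(7, 2, 0.3) is False
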
@@ -1444,7 +1444,7 @@ def example_usage():
     """
     for task in available_tasks:
         sample("final", task=task)
+    sample("final", task="tts-nar")
     engines.quit()

View File

@@ -735,13 +735,15 @@ class Base(nn.Module):
         if self.version >= 6:
             self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
+        self.resp_parallel_training = True #
+        self.monolithic_audio_encoder = True #
         if self.version >= 7:
             pd_model = d_model // 4
             pd_ffn = pd_model * 4
             pd_heads = n_heads // 4
             pd_layers = 1
-            if False:
+            if self.monolithic_audio_encoder:
                 self.audio_emb = AudioEncoder(
                     n_tokens=n_audio_tokens + 1, # masked token
                     n_levels=self.n_resp_levels,
@@ -763,7 +765,7 @@ class Base(nn.Module):
                 self.n_resp_levels,
                 d_model,
                 dict(
-                    vocab_size=n_audio_tokens,
+                    vocab_size=n_audio_tokens + 1,
                     hidden_size=pd_model,
                     max_position_embeddings=max_position_embeddings,
                     intermediate_size=pd_ffn,
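
For a feel of the sizes involved: the per-level decoder config scales everything down by four from the main model, and its vocabulary now carries one extra entry for the masked token. With d_model=1024, n_heads=16, and n_audio_tokens=1024 (assumed values for illustration, not taken from the diff), that works out to:

    d_model, n_heads, n_audio_tokens = 1024, 16, 1024  # assumed defaults, not from the diff

    pd_model = d_model // 4          # 256
    pd_ffn = pd_model * 4            # 1024
    pd_heads = n_heads // 4          # 4
    pd_layers = 1
    vocab_size = n_audio_tokens + 1  # 1025, the +1 being the masked token

    print(pd_model, pd_ffn, pd_heads, pd_layers, vocab_size)  # 256 1024 4 1 1025
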
@@ -1134,7 +1136,7 @@ class Base(nn.Module):
             inputs[i].append( ( "resp", resps_list[i] ) )
             if self.version >= 7:
-                classifier_level = f"NAR:{quant_level}:{quant_level}"
+                classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
             inputs[i].append( ("classifier_level", classifier_level) )
         # Audio length prediction task
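
The classifier level string now distinguishes AR from NAR samples: the "N" prefix only appears when the input carries a timestep, i.e. the masked/NAR path. For illustration:

    def classifier_level_name(quant_level, timestep=None):
        # same f-string as above: "AR:q:q" without a timestep, "NAR:q:q" with one
        return f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"

    assert classifier_level_name(0) == "AR:0:0"
    assert classifier_level_name(3, timestep=0.5) == "NAR:3:3"
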
@@ -1562,7 +1564,7 @@ class Base(nn.Module):
         else:
             accuracy_metric = MulticlassAccuracy(
                 logit.shape[-1],
-                top_k = 10,
+                top_k = min(logit.shape[0], 10),
                 average="micro",
                 multidim_average="global",
                 ignore_index = -100
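
The accuracy metric's top_k is now clamped by the first logit dimension, so a very short sequence never requests a larger k than it has positions. A toy illustration of the same construction (torchmetrics is what the surrounding code already uses; the tensor shapes here are made up):

    import torch
    from torchmetrics.classification import MulticlassAccuracy

    logit = torch.randn(4, 1025)             # only 4 positions in this toy sequence
    target = torch.randint(0, 1025, (4,))

    accuracy_metric = MulticlassAccuracy(
        logit.shape[-1],
        top_k=min(logit.shape[0], 10),       # clamps to 4 here instead of 10
        average="micro",
        multidim_average="global",
        ignore_index=-100,
    )
    print(accuracy_metric(logit, target))
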
@@ -1662,9 +1664,24 @@ class Base(nn.Module):
             if loss_factor == 0.0:
                 continue
+            # cringe way to deduce "requested" level
+            level = quant_level
+            for i in range( self.n_resp_levels ):
+                if classifier_level == f'NAR:{i}:{i}':
+                    level = i
+                    break
             if logits[batch_index].dim() < 3:
                 nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
+                if name == "resp":
+                    name = f'{name}[{quant_level}]'
+            elif not self.resp_parallel_training:
+                if name == "resp":
+                    name = f'{name}[{level}]'
+                sequence = token if token.dim() <= 1 else token[:, level]
+                nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
             else:
                 nlls = []
                 accs = []
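
The self-described "cringe" loop just recovers which RVQ level a batch entry targets from its classifier level string, falling back to the current quant_level; the new elif then scores only that level's logits when resp_parallel_training is disabled. The level deduction in isolation (a standalone restatement, with 8 response levels assumed):

    def deduce_level(classifier_level, quant_level, n_resp_levels=8):
        for i in range(n_resp_levels):
            if classifier_level == f"NAR:{i}:{i}":
                return i
        return quant_level  # fallback when the string names no explicit level

    assert deduce_level("NAR:3:3", 0) == 3
    assert deduce_level("AR:0:0", 2) == 2
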
@@ -1673,15 +1690,29 @@ class Base(nn.Module):
                     sequence = token if token.dim() <= 1 else token[:, level]
                     nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
-                    if nll:
-                        nlls.append( nll )
-                    if metrics:
-                        accs.append( metrics )
+                    if name == "resp":
+                        if nll is not None:
+                            if f'{name}[{level}].nll' not in loss:
+                                loss[f'{name}[{level}].nll'] = []
+                            loss[f"{name}[{level}].nll"].append( nll * loss_factor )
+                        if metrics is not None:
+                            if f'{name}[{level}].acc' not in stats:
+                                stats[f'{name}[{level}].acc'] = []
+                            stats[f"{name}[{level}].acc"].append( metrics )
-                if nlls:
-                    nll = sum(nlls) / len(nlls)
-                if accs:
-                    accs = sum(accs) / len(accs)
+                        nll = None
+                        metrics = None
+                    else:
+                        if nll:
+                            nlls.append( nll )
+                        if metrics:
+                            accs.append( metrics )
+                else:
+                    if nlls:
+                        nll = sum(nlls) / len(nlls)
+                    if accs:
+                        accs = sum(accs) / len(accs)
             if nll is not None:
                 if f'{name}.nll' not in loss:
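
When training all levels in parallel, response losses and accuracies are now filed under per-level keys ("resp[0].nll", "resp[0].acc", and so on) rather than averaged into a single bucket. A compact sketch of the same bookkeeping (dict.setdefault stands in for the explicit membership checks above; the numbers are placeholders):

    loss, stats = {}, {}

    def track(name, level, nll=None, metrics=None, loss_factor=1.0):
        if nll is not None:
            loss.setdefault(f"{name}[{level}].nll", []).append(nll * loss_factor)
        if metrics is not None:
            stats.setdefault(f"{name}[{level}].acc", []).append(metrics)

    track("resp", 0, nll=2.31, metrics=0.42)
    track("resp", 1, nll=2.05, metrics=0.47)
    print(sorted(loss), sorted(stats))  # ['resp[0].nll', 'resp[1].nll'] ['resp[0].acc', 'resp[1].acc']
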
@@ -1701,6 +1732,16 @@ class Base(nn.Module):
             if logits[batch_index].dim() < 3:
                 sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
                 nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
+            elif not self.resp_parallel_training:
+                # cringe way to deduce "requested" level
+                level = 0
+                for i in range( self.n_resp_levels ):
+                    if classifier_level == f'NAR:{i}:{i}':
+                        level = i
+                        break
+                sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
+                sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
+                nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
             else:
                 nlls = []
                 accs = []
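
The list comprehension above handles mixed targets: 1-D sequences pass through untouched, while stacked (T, n_resp_levels) audio targets are sliced down to the one requested level before being joined and scored. A toy example with made-up shapes:

    import torch

    targets = [torch.tensor([5, 9, 2]),           # 1-D target, kept as-is
               torch.randint(0, 1024, (6, 8))]    # (T=6, n_resp_levels=8) audio target
    level = 3

    sequence = [x if x.dim() <= 1 else x[:, level] for x in targets]
    print([t.shape for t in sequence])  # [torch.Size([3]), torch.Size([6])]
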
@@ -1782,7 +1823,7 @@ class Base(nn.Module):
         # needs to be done here as we still have our raw inputs
         position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
         classifier_levels = self.get_input( inputs, name="classifier_level" )
-        causal_levels = [ "AR:0:0", "stt", "len", "phn" ]
+        causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
         # right now limit to new versions because I need to retrain the model for noncausal masks...
         is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
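
Previously only "AR:0:0" was treated as causal; now every per-level AR classifier is, alongside the stt/len/phn tasks, and that list is what drives the per-sample causal mask selection. A quick illustration (the value of n_resp_levels is assumed):

    n_resp_levels = 8  # assumed for illustration
    causal_levels = ["stt", "len", "phn"] + [f"AR:{i}:{i}" for i in range(n_resp_levels)]

    classifier_levels = ["AR:0:0", "NAR:3:3", "len"]
    is_causal = [l in causal_levels for l in classifier_levels]
    print(is_causal)  # [True, False, True]
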
@@ -1803,7 +1844,7 @@ class Base(nn.Module):
         if self.version >= 7:
             logits = [ logit for logit in logits ]
-            audio_decoder_levels = [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
+            audio_decoder_levels = [ f"AR:{i}:{i}" for i in range(self.n_resp_levels) ] + [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
             decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
             classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
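
With AR levels added to audio_decoder_levels, both AR and NAR samples of every RVQ level are routed through the per-level audio decoder heads, and only the remaining tasks fall through to the plain classifiers. A toy routing example (two response levels assumed for brevity):

    n_resp_levels = 2  # assumed small value for illustration
    audio_decoder_levels = [f"AR:{i}:{i}" for i in range(n_resp_levels)] \
                         + [f"NAR:{i}:{i}" for i in range(n_resp_levels)]

    classifier_levels = ["AR:0:0", "len", "NAR:1:1", "phn"]
    decoders_indices = [b for b, lvl in enumerate(classifier_levels) if lvl in audio_decoder_levels]
    classifiers_indices = [b for b, lvl in enumerate(classifier_levels) if lvl not in audio_decoder_levels]
    print(decoders_indices, classifiers_indices)  # [0, 2] [1, 3]
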