what has science done
This commit is contained in:
parent
cebacc2303
commit
54d65cf37d
|
@ -396,11 +396,12 @@ def load_engines(training=True, **model_kwargs):
|
|||
if cfg.lora is not None:
|
||||
key_name = cfg.lora.full_name
|
||||
|
||||
kwargs['id'] = 'job'
|
||||
salt = "run"
|
||||
kwargs['id'] = f'{key_name}-{salt}'
|
||||
kwargs['resume'] = 'allow'
|
||||
if world_size() > 1:
|
||||
kwargs["group"] = "DDP"
|
||||
kwargs['id'] = f'job-{global_rank()}'
|
||||
kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
|
||||
|
||||
|
||||
engine.wandb = wandb.init(project=key_name, **kwargs)
|
||||
|
|
|
@ -171,7 +171,7 @@ class AR_NAR(Base):
|
|||
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
|
||||
|
||||
# only apply stop token for RVQ level 0
|
||||
if quant_level <= 0 and timesteps[i] is None:
|
||||
if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None):
|
||||
# append stop tokens for AR
|
||||
if task not in text_task:
|
||||
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
|
||||
|
@ -1444,7 +1444,7 @@ def example_usage():
|
|||
"""
|
||||
|
||||
for task in available_tasks:
|
||||
sample("final", task=task)
|
||||
sample("final", task="tts-nar")
|
||||
|
||||
engines.quit()
|
||||
|
||||
|
|
|
@ -735,13 +735,15 @@ class Base(nn.Module):
|
|||
if self.version >= 6:
|
||||
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
|
||||
|
||||
self.resp_parallel_training = True #
|
||||
self.monolithic_audio_encoder = True #
|
||||
if self.version >= 7:
|
||||
pd_model = d_model // 4
|
||||
pd_ffn = pd_model * 4
|
||||
pd_heads = n_heads // 4
|
||||
pd_layers = 1
|
||||
|
||||
if False:
|
||||
if self.monolithic_audio_encoder:
|
||||
self.audio_emb = AudioEncoder(
|
||||
n_tokens=n_audio_tokens + 1, # masked token
|
||||
n_levels=self.n_resp_levels,
|
||||
|
@ -763,7 +765,7 @@ class Base(nn.Module):
|
|||
self.n_resp_levels,
|
||||
d_model,
|
||||
dict(
|
||||
vocab_size=n_audio_tokens,
|
||||
vocab_size=n_audio_tokens + 1,
|
||||
hidden_size=pd_model,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
intermediate_size=pd_ffn,
|
||||
|
@ -1134,7 +1136,7 @@ class Base(nn.Module):
|
|||
inputs[i].append( ( "resp", resps_list[i] ) )
|
||||
|
||||
if self.version >= 7:
|
||||
classifier_level = f"NAR:{quant_level}:{quant_level}"
|
||||
classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
|
||||
|
||||
inputs[i].append( ("classifier_level", classifier_level) )
|
||||
# Audio length prediction task
|
||||
|
@ -1562,7 +1564,7 @@ class Base(nn.Module):
|
|||
else:
|
||||
accuracy_metric = MulticlassAccuracy(
|
||||
logit.shape[-1],
|
||||
top_k = 10,
|
||||
top_k = min(logit.shape[0], 10),
|
||||
average="micro",
|
||||
multidim_average="global",
|
||||
ignore_index = -100
|
||||
|
@ -1662,9 +1664,24 @@ class Base(nn.Module):
|
|||
|
||||
if loss_factor == 0.0:
|
||||
continue
|
||||
|
||||
# cringe way to deduce "requested" level
|
||||
level = quant_level
|
||||
for i in range( self.n_resp_levels ):
|
||||
if classifier_level == f'NAR:{i}:{i}':
|
||||
level = i
|
||||
break
|
||||
|
||||
if logits[batch_index].dim() < 3:
|
||||
nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
|
||||
|
||||
if name == "resp":
|
||||
name = f'{name}[{quant_level}]'
|
||||
elif not self.resp_parallel_training:
|
||||
if name == "resp":
|
||||
name = f'{name}[{level}]'
|
||||
sequence = token if token.dim() <= 1 else token[:, level]
|
||||
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
|
||||
else:
|
||||
nlls = []
|
||||
accs = []
|
||||
|
@ -1673,15 +1690,29 @@ class Base(nn.Module):
|
|||
sequence = token if token.dim() <= 1 else token[:, level]
|
||||
nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
|
||||
|
||||
if nll:
|
||||
nlls.append( nll )
|
||||
if metrics:
|
||||
accs.append( metrics )
|
||||
if name == "resp":
|
||||
if nll is not None:
|
||||
if f'{name}[{level}].nll' not in loss:
|
||||
loss[f'{name}[{level}].nll'] = []
|
||||
loss[f"{name}[{level}].nll"].append( nll * loss_factor )
|
||||
|
||||
if metrics is not None:
|
||||
if f'{name}[{level}].acc' not in stats:
|
||||
stats[f'{name}[{level}].acc'] = []
|
||||
stats[f"{name}[{level}].acc"].append( metrics )
|
||||
|
||||
if nlls:
|
||||
nll = sum(nlls) / len(nlls)
|
||||
if accs:
|
||||
accs = sum(accs) / len(accs)
|
||||
nll = None
|
||||
metrics = None
|
||||
else:
|
||||
if nll:
|
||||
nlls.append( nll )
|
||||
if metrics:
|
||||
accs.append( metrics )
|
||||
else:
|
||||
if nlls:
|
||||
nll = sum(nlls) / len(nlls)
|
||||
if accs:
|
||||
accs = sum(accs) / len(accs)
|
||||
|
||||
if nll is not None:
|
||||
if f'{name}.nll' not in loss:
|
||||
|
@ -1701,6 +1732,16 @@ class Base(nn.Module):
|
|||
if logits[batch_index].dim() < 3:
|
||||
sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
|
||||
nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
|
||||
elif not self.resp_parallel_training:
|
||||
# cringe way to deduce "requested" level
|
||||
level = 0
|
||||
for i in range( self.n_resp_levels ):
|
||||
if classifier_level == f'NAR:{i}:{i}':
|
||||
level = i
|
||||
break
|
||||
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
||||
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
||||
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
|
||||
else:
|
||||
nlls = []
|
||||
accs = []
|
||||
|
@ -1782,7 +1823,7 @@ class Base(nn.Module):
|
|||
# needs to be done here as we still have our raw inputs
|
||||
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
|
||||
classifier_levels = self.get_input( inputs, name="classifier_level" )
|
||||
causal_levels = [ "AR:0:0", "stt", "len", "phn" ]
|
||||
causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
|
||||
|
||||
# right now limit to new versions because I need to retrain the model for noncausal masks...
|
||||
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
|
||||
|
@ -1803,7 +1844,7 @@ class Base(nn.Module):
|
|||
if self.version >= 7:
|
||||
logits = [ logit for logit in logits ]
|
||||
|
||||
audio_decoder_levels = [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
|
||||
audio_decoder_levels = [ f"AR:{i}:{i}" for i in range(self.n_resp_levels) ] + [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
|
||||
|
||||
decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
|
||||
classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
|
||||
|
|
Loading…
Reference in New Issue
Block a user