should never have trusted mse_loss, it never works

This commit is contained in:
mrq 2025-03-31 20:59:13 -05:00
parent 99f251c768
commit a1184586ef
3 changed files with 30 additions and 10 deletions

View File

@ -297,6 +297,7 @@ class ModelExperimentalSettings:
# * NAR-demask would semi-doubly train for AR
# * the model wouldn't also need to learn when to predict the token in place
len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
len_use_logits: bool = False # whether to treat duration prediction as a nll/logits task or use a raw, continuous float
len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
parallel_attention_mask_dropout: float = 0.0 # randomly sets to a causal attention mask when training NAR-len demasking
layer_dropout_p: float = 0.0 # performs layer dropout, which I readded because it might actually help since the reference model had this at 0.1

View File

@ -959,8 +959,10 @@ def example_usage():
phns_list, proms_list, resp_list, task_list = sample_data( task )
if task == "tts-nar":
# len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
print( len_list )
len_list = [ r.shape[0] for r in resp_list ]
print( len_list )
resps_list = engine( phns_list=phns_list, proms_list=proms_list, len_list=len_list )
else:
resps_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )

View File

@ -293,6 +293,7 @@ class Base_V2(nn.Module):
resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
len_parallel_training = config.experimental.len_parallel_training if config is not None else False
len_use_logits = config.experimental.len_use_logits if config is not None else True
predict_causally = config.experimental.predict_causally if config is not None else False
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
@ -384,6 +385,7 @@ class Base_V2(nn.Module):
self.predict_causally = predict_causally
self.resp_parallel_training = resp_parallel_training
self.len_parallel_training = len_parallel_training
self.len_use_logits = len_use_logits
self.unified_position_ids = unified_position_ids
self.inject_timestep_embedding = False # results in bad output
self.masking_ratio = masking_ratio
@ -433,7 +435,7 @@ class Base_V2(nn.Module):
self.n_resp_levels,
use_ln=per_level_normalization,
)
self.len_decoder = AuxDecoder( d_model, 1 )
self.len_decoder = AuxDecoder( d_model, 1 if not len_use_logits else (10 * 5) )
self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
self.text_decoder = AuxDecoder( d_model, n_text_tokens )
@ -870,6 +872,10 @@ class Base_V2(nn.Module):
return input
# handles "tokenizing" an integer
def tokenize_duration( seq_lens, device, dtype=torch.int64 ):
return torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in seq_lens], device=device, dtype=dtype)
k_lo, k_hi = 1, 20
level_loss_factors = self.audio_level_loss_factors
@ -1069,11 +1075,13 @@ class Base_V2(nn.Module):
if logits_aux is not None:
len_factor = self.len_loss_factor # 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
aux_loss_logit = torch.cat( logits_aux )
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
if self.len_use_logits:
aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).squeeze(0)
loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
else:
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
return LossStats(loss, stats)
@ -1225,17 +1233,26 @@ class Base_V2(nn.Module):
# do duration prediction
logits_aux = self.len_decoder( output.logits )
# it's more accurate this way
logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
logits_aux = [ logit[..., -1, :] for logit, aux_len in zip(logits_aux, aux_lens) ]
# reshape logits
if self.len_use_logits:
# get tokens
logits_aux = [ logit.view(5, 10).argmax(dim=-1) for logit in logits_aux ]
# stitch
logits_aux = [ int("".join([ str(t.item()) for t in logit ])) / self.audio_frames_per_second for logit in logits_aux ]
logits = logits_aux
# compute loss if the target is given
else:
# do duration prediction
if self.len_parallel_training:
logits_aux = self.len_decoder( output.logits )
# only keep the input
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :] for logit, aux_len in zip(logits_aux, aux_lens) ]
# reshape logits
if self.len_use_logits:
logits_aux = [ logit.view(5, 10) for logit in logits_aux ]
else:
logits_aux = None