should never have trusted mse_loss, it never works
This commit is contained in:
parent
99f251c768
commit
a1184586ef
|
@ -297,6 +297,7 @@ class ModelExperimentalSettings:
|
|||
# * NAR-demask would semi-doubly train for AR
|
||||
# * the model wouldn't also need to learn when to predict the token in place
|
||||
len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
|
||||
len_use_logits: bool = False # whether to treat duration prediction as a nll/logits task or use a raw, continuous float
|
||||
len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
|
||||
parallel_attention_mask_dropout: float = 0.0 # randomly sets to a causal attention mask when training NAR-len demasking
|
||||
layer_dropout_p: float = 0.0 # performs layer dropout, which I readded because it might actually help since the reference model had this at 0.1
|
||||
|
|
|
@ -959,8 +959,10 @@ def example_usage():
|
|||
phns_list, proms_list, resp_list, task_list = sample_data( task )
|
||||
|
||||
if task == "tts-nar":
|
||||
# len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
|
||||
len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
|
||||
print( len_list )
|
||||
len_list = [ r.shape[0] for r in resp_list ]
|
||||
print( len_list )
|
||||
resps_list = engine( phns_list=phns_list, proms_list=proms_list, len_list=len_list )
|
||||
else:
|
||||
resps_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
|
||||
|
|
|
@ -293,6 +293,7 @@ class Base_V2(nn.Module):
|
|||
|
||||
resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
|
||||
len_parallel_training = config.experimental.len_parallel_training if config is not None else False
|
||||
len_use_logits = config.experimental.len_use_logits if config is not None else True
|
||||
predict_causally = config.experimental.predict_causally if config is not None else False
|
||||
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
|
||||
audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
|
||||
|
@ -384,6 +385,7 @@ class Base_V2(nn.Module):
|
|||
self.predict_causally = predict_causally
|
||||
self.resp_parallel_training = resp_parallel_training
|
||||
self.len_parallel_training = len_parallel_training
|
||||
self.len_use_logits = len_use_logits
|
||||
self.unified_position_ids = unified_position_ids
|
||||
self.inject_timestep_embedding = False # results in bad output
|
||||
self.masking_ratio = masking_ratio
|
||||
|
@ -433,7 +435,7 @@ class Base_V2(nn.Module):
|
|||
self.n_resp_levels,
|
||||
use_ln=per_level_normalization,
|
||||
)
|
||||
self.len_decoder = AuxDecoder( d_model, 1 )
|
||||
self.len_decoder = AuxDecoder( d_model, 1 if not len_use_logits else (10 * 5) )
|
||||
self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
|
||||
self.text_decoder = AuxDecoder( d_model, n_text_tokens )
|
||||
|
||||
|
@ -870,6 +872,10 @@ class Base_V2(nn.Module):
|
|||
|
||||
return input
|
||||
|
||||
# handles "tokenizing" an integer
|
||||
def tokenize_duration( seq_lens, device, dtype=torch.int64 ):
|
||||
return torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in seq_lens], device=device, dtype=dtype)
|
||||
|
||||
k_lo, k_hi = 1, 20
|
||||
level_loss_factors = self.audio_level_loss_factors
|
||||
|
||||
|
@ -1069,11 +1075,13 @@ class Base_V2(nn.Module):
|
|||
if logits_aux is not None:
|
||||
len_factor = self.len_loss_factor # 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
|
||||
aux_loss_logit = torch.cat( logits_aux )
|
||||
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
|
||||
#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
|
||||
|
||||
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
|
||||
loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
|
||||
|
||||
if self.len_use_logits:
|
||||
aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).squeeze(0)
|
||||
loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
|
||||
else:
|
||||
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
|
||||
loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
|
||||
|
||||
return LossStats(loss, stats)
|
||||
|
||||
|
@ -1225,17 +1233,26 @@ class Base_V2(nn.Module):
|
|||
# do duration prediction
|
||||
logits_aux = self.len_decoder( output.logits )
|
||||
# it's more accurate this way
|
||||
logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||
logits_aux = [ logit[..., -1, :] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||
# reshape logits
|
||||
if self.len_use_logits:
|
||||
# get tokens
|
||||
logits_aux = [ logit.view(5, 10).argmax(dim=-1) for logit in logits_aux ]
|
||||
# stitch
|
||||
logits_aux = [ int("".join([ str(t.item()) for t in logit ])) / self.audio_frames_per_second for logit in logits_aux ]
|
||||
|
||||
logits = logits_aux
|
||||
|
||||
# compute loss if the target is given
|
||||
else:
|
||||
# do duration prediction
|
||||
if self.len_parallel_training:
|
||||
logits_aux = self.len_decoder( output.logits )
|
||||
# only keep the input
|
||||
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||
|
||||
# reshape logits
|
||||
if self.len_use_logits:
|
||||
logits_aux = [ logit.view(5, 10) for logit in logits_aux ]
|
||||
else:
|
||||
logits_aux = None
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user