should never have trusted mse_loss, it never works
parent 99f251c768
commit a1184586ef
@@ -297,6 +297,7 @@ class ModelExperimentalSettings:
 	# * NAR-demask would semi-doubly train for AR
 	# * the model wouldn't also need to learn when to predict the token in place
 	len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
+	len_use_logits: bool = False # whether to treat duration prediction as a nll/logits task or use a raw, continuous float
 	len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
 	parallel_attention_mask_dropout: float = 0.0 # randomly sets to a causal attention mask when training NAR-len demasking
 	layer_dropout_p: float = 0.0 # performs layer dropout, which I readded because it might actually help since the reference model had this at 0.1
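The new len_use_logits flag switches duration prediction from regressing a single raw float to a digit-classification (nll) task: the frame count is zero-padded to five decimal digits and each digit is predicted as a 10-way classification, which is why the decoder head later grows to 10 * 5 = 50 outputs. A minimal sketch of the round-trip, with illustrative helper names that are not in the repo:

def duration_to_digits( frames: int ) -> list[int]:
	# mirrors str( l ).zfill(5) in the diff below: 750 -> [0, 0, 7, 5, 0]
	return [ int(c) for c in str( frames ).zfill(5) ]

def digits_to_duration( digits: list[int] ) -> int:
	# mirrors the int("".join(...)) stitch used at inference: [0, 0, 7, 5, 0] -> 750
	return int( "".join( str(d) for d in digits ) )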
@@ -959,8 +959,10 @@ def example_usage():
 		phns_list, proms_list, resp_list, task_list = sample_data( task )
 
 		if task == "tts-nar":
-			# len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
+			len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
+			print( len_list )
 			len_list = [ r.shape[0] for r in resp_list ]
+			print( len_list )
 			resps_list = engine( phns_list=phns_list, proms_list=proms_list, len_list=len_list )
 		else:
 			resps_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
@@ -293,6 +293,7 @@ class Base_V2(nn.Module):
 
 		resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
 		len_parallel_training = config.experimental.len_parallel_training if config is not None else False
+		len_use_logits = config.experimental.len_use_logits if config is not None else True
 		predict_causally = config.experimental.predict_causally if config is not None else False
 		monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
 		audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
@@ -384,6 +385,7 @@ class Base_V2(nn.Module):
 		self.predict_causally = predict_causally
 		self.resp_parallel_training = resp_parallel_training
 		self.len_parallel_training = len_parallel_training
+		self.len_use_logits = len_use_logits
 		self.unified_position_ids = unified_position_ids
 		self.inject_timestep_embedding = False # results in bad output
 		self.masking_ratio = masking_ratio
@@ -433,7 +435,7 @@ class Base_V2(nn.Module):
 			self.n_resp_levels,
 			use_ln=per_level_normalization,
 		)
-		self.len_decoder = AuxDecoder( d_model, 1 )
+		self.len_decoder = AuxDecoder( d_model, 1 if not len_use_logits else (10 * 5) )
 		self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
 		self.text_decoder = AuxDecoder( d_model, n_text_tokens )
 
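With len_use_logits enabled, len_decoder widens from one output to 50, which later code reinterprets as a (5 digits, 10 classes) grid. A hypothetical shape walk-through, using a plain Linear as a stand-in for AuxDecoder (assumed here to be a d_model -> n_out projection):

import torch

d_model = 1024                                     # illustrative width
len_decoder = torch.nn.Linear( d_model, 10 * 5 )   # stand-in for AuxDecoder( d_model, 10 * 5 )

hidden = torch.randn( 1, d_model )                 # one position's hidden state
logits = len_decoder( hidden )                     # -> [1, 50]
digit_logits = logits.view( 5, 10 )                # -> [5, 10]: class scores per digit slot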
@@ -870,6 +872,10 @@ class Base_V2(nn.Module):
 
 			return input
 
+		# handles "tokenizing" an integer
+		def tokenize_duration( seq_lens, device, dtype=torch.int64 ):
+			return torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in seq_lens], device=device, dtype=dtype)
+
 		k_lo, k_hi = 1, 20
 		level_loss_factors = self.audio_level_loss_factors
 
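Taking the new helper at face value, a quick usage example (output shown as a comment):

tokenize_duration( [ 750, 12345 ], device="cpu" )
# -> tensor([[0, 0, 7, 5, 0],
#            [1, 2, 3, 4, 5]])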
@@ -1069,11 +1075,13 @@ class Base_V2(nn.Module):
 		if logits_aux is not None:
 			len_factor = self.len_loss_factor # 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
 			aux_loss_logit = torch.cat( logits_aux )
-			#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
-			#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
-			aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
-			loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
+			if self.len_use_logits:
+				aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).squeeze(0)
+				loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
+			else:
+				aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
+				loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
 
 		return LossStats(loss, stats)
 
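This hunk is the point of the commit title: mse_loss on a raw duration produces losses on a completely different scale from the token cross-entropies (hence the tiny len_loss_factor and its float16 trouble), while per-digit cross-entropy stays bounded no matter how long the utterance is. A rough sketch with illustrative values, assuming a single 750-frame utterance and 75 frames per second:

import torch
import torch.nn.functional as F

digit_logits = torch.randn( 5, 10 )                 # logits path: 5 digits x 10 classes
digit_target = torch.tensor( [ 0, 0, 7, 5, 0 ] )    # digits of "00750"
ce = F.cross_entropy( digit_logits, digit_target )  # ~ln(10), about 2.3 for random logits

float_pred = torch.randn( 1 )                       # raw float path instead
float_target = torch.tensor( [ 750 / 75.0 ] )       # duration in seconds
mse = F.mse_loss( float_pred, float_target )        # ~(10 - pred)^2, easily two orders larger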
@@ -1225,17 +1233,26 @@ class Base_V2(nn.Module):
 			# do duration prediction
 			logits_aux = self.len_decoder( output.logits )
 			# it's more accurate this way
-			logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
+			logits_aux = [ logit[..., -1, :] for logit, aux_len in zip(logits_aux, aux_lens) ]
+			# reshape logits
+			if self.len_use_logits:
+				# get tokens
+				logits_aux = [ logit.view(5, 10).argmax(dim=-1) for logit in logits_aux ]
+				# stitch
+				logits_aux = [ int("".join([ str(t.item()) for t in logit ])) / self.audio_frames_per_second for logit in logits_aux ]
 
 			logits = logits_aux
 
 		# compute loss if the target is given
 		else:
 			# do duration prediction
 			if self.len_parallel_training:
 				logits_aux = self.len_decoder( output.logits )
 				# only keep the input
-				logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
+				logits_aux = [ logit[..., aux_len[0] + aux_len[1], :] for logit, aux_len in zip(logits_aux, aux_lens) ]
+				# reshape logits
+				if self.len_use_logits:
+					logits_aux = [ logit.view(5, 10) for logit in logits_aux ]
 			else:
 				logits_aux = None
 
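At inference the logits path runs the tokenization in reverse, as the hunk above shows: argmax each digit slot, stitch the digits back into an integer frame count, then divide by the frame rate. A small sketch (75 frames per second assumed in place of self.audio_frames_per_second):

import torch

digit_logits = torch.randn( 5, 10 )                       # one utterance's duration head output
digits = digit_logits.argmax( dim=-1 )                    # e.g. tensor([0, 0, 7, 5, 0])
frames = int( "".join( str(d.item()) for d in digits ) )  # e.g. "00750" -> 750
duration = frames / 75.0                                  # rescaled by the frame rate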