This commit is contained in:
mrq 2023-03-26 11:05:50 +00:00
parent 41d47c7c2a
commit c89c648b4a

View File

@ -774,7 +774,8 @@ class TrainingState():
keys = { keys = {
'lrs': ['lr'], 'lrs': ['lr'],
'losses': ['loss_text_ce', 'loss_mel_ce'], 'losses': ['loss_text_ce', 'loss_mel_ce'],
'accuracy': [], 'accuracies': [],
'grad_norms': [],
} }
if args.tts_backend == "vall-e": if args.tts_backend == "vall-e":
keys['lrs'] = [ keys['lrs'] = [
@ -797,6 +798,7 @@ class TrainingState():
'ar-half.loss.acc', 'nar-half.loss.acc', 'ar-half.loss.acc', 'nar-half.loss.acc',
'ar-quarter.loss.acc', 'nar-quarter.loss.acc', 'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
] ]
keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']
for k in keys['lrs']: for k in keys['lrs']:
if k not in self.info: if k not in self.info:
@ -823,7 +825,7 @@ class TrainingState():
self.losses.append( self.statistics['loss'][-1] ) self.losses.append( self.statistics['loss'][-1] )
for k in ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']: for k in keys['grad_norms']:
if k not in self.info: if k not in self.info:
continue continue
self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k}) self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})