fixes #176
This commit is contained in:
parent
41d47c7c2a
commit
c89c648b4a
|
@ -774,7 +774,8 @@ class TrainingState():
|
|||
keys = {
|
||||
'lrs': ['lr'],
|
||||
'losses': ['loss_text_ce', 'loss_mel_ce'],
|
||||
'accuracy': [],
|
||||
'accuracies': [],
|
||||
'grad_norms': [],
|
||||
}
|
||||
if args.tts_backend == "vall-e":
|
||||
keys['lrs'] = [
|
||||
|
@ -797,6 +798,7 @@ class TrainingState():
|
|||
'ar-half.loss.acc', 'nar-half.loss.acc',
|
||||
'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
|
||||
]
|
||||
keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']
|
||||
|
||||
for k in keys['lrs']:
|
||||
if k not in self.info:
|
||||
|
@ -823,7 +825,7 @@ class TrainingState():
|
|||
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
|
||||
for k in ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']:
|
||||
for k in keys['grad_norms']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
||||
|
|
Loading…
Reference in New Issue
Block a user