diff --git a/src/utils.py b/src/utils.py index baf5dbc..d6f6d8f 100755 --- a/src/utils.py +++ b/src/utils.py @@ -774,7 +774,8 @@ class TrainingState(): keys = { 'lrs': ['lr'], 'losses': ['loss_text_ce', 'loss_mel_ce'], - 'accuracy': [], + 'accuracies': [], + 'grad_norms': [], } if args.tts_backend == "vall-e": keys['lrs'] = [ @@ -797,6 +798,7 @@ class TrainingState(): 'ar-half.loss.acc', 'nar-half.loss.acc', 'ar-quarter.loss.acc', 'nar-quarter.loss.acc', ] + keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm'] for k in keys['lrs']: if k not in self.info: @@ -823,7 +825,7 @@ class TrainingState(): self.losses.append( self.statistics['loss'][-1] ) - for k in ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']: + for k in keys['grad_norms']: if k not in self.info: continue self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})