forked from mrq/ai-voice-cloning
fixes #176
This commit is contained in:
parent
41d47c7c2a
commit
c89c648b4a
|
@ -774,7 +774,8 @@ class TrainingState():
|
||||||
keys = {
|
keys = {
|
||||||
'lrs': ['lr'],
|
'lrs': ['lr'],
|
||||||
'losses': ['loss_text_ce', 'loss_mel_ce'],
|
'losses': ['loss_text_ce', 'loss_mel_ce'],
|
||||||
'accuracy': [],
|
'accuracies': [],
|
||||||
|
'grad_norms': [],
|
||||||
}
|
}
|
||||||
if args.tts_backend == "vall-e":
|
if args.tts_backend == "vall-e":
|
||||||
keys['lrs'] = [
|
keys['lrs'] = [
|
||||||
|
@ -797,6 +798,7 @@ class TrainingState():
|
||||||
'ar-half.loss.acc', 'nar-half.loss.acc',
|
'ar-half.loss.acc', 'nar-half.loss.acc',
|
||||||
'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
|
'ar-quarter.loss.acc', 'nar-quarter.loss.acc',
|
||||||
]
|
]
|
||||||
|
keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']
|
||||||
|
|
||||||
for k in keys['lrs']:
|
for k in keys['lrs']:
|
||||||
if k not in self.info:
|
if k not in self.info:
|
||||||
|
@ -823,7 +825,7 @@ class TrainingState():
|
||||||
|
|
||||||
self.losses.append( self.statistics['loss'][-1] )
|
self.losses.append( self.statistics['loss'][-1] )
|
||||||
|
|
||||||
for k in ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm']:
|
for k in keys['grad_norms']:
|
||||||
if k not in self.info:
|
if k not in self.info:
|
||||||
continue
|
continue
|
||||||
self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
self.statistics['grad_norm'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
||||||
|
|
Loading…
Reference in New Issue
Block a user