@ -1458,6 +1458,7 @@ class TrainingState():
' lrs ' : [ ' lr ' ] ,
' losses ' : [ ' loss_text_ce ' , ' loss_mel_ce ' ] ,
' accuracies ' : [ ] ,
' precisions ' : [ ] ,
' grad_norms ' : [ ] ,
}
if args . tts_backend == " vall-e " :
@ -1481,6 +1482,11 @@ class TrainingState():
' ar-half.loss.acc ' , ' nar-half.loss.acc ' ,
' ar-quarter.loss.acc ' , ' nar-quarter.loss.acc ' ,
]
keys [ ' precisions ' ] = [
' ar.loss.precision ' , ' nar.loss.precision ' ,
' ar-half.loss.precision ' , ' nar-half.loss.precision ' ,
' ar-quarter.loss.precision ' , ' nar-quarter.loss.precision ' ,
]
keys [ ' grad_norms ' ] = [ ' ar.grad_norm ' , ' nar.grad_norm ' , ' ar-half.grad_norm ' , ' nar-half.grad_norm ' , ' ar-quarter.grad_norm ' , ' nar-quarter.grad_norm ' ]
for k in keys [ ' lrs ' ] :
@ -1494,6 +1500,12 @@ class TrainingState():
continue
self . statistics [ ' loss ' ] . append ( { ' epoch ' : epoch , ' it ' : self . it , ' value ' : self . info [ k ] , ' type ' : k } )
for k in keys [ ' precisions ' ] :
if k not in self . info :
continue
self . statistics [ ' loss ' ] . append ( { ' epoch ' : epoch , ' it ' : self . it , ' value ' : self . info [ k ] , ' type ' : k } )
for k in keys [ ' losses ' ] :
if k not in self . info :
@ -1671,7 +1683,10 @@ class TrainingState():
for k in data :
if data [ k ] is None :
continue
averager [ ' metrics ' ] [ k ] . append ( data [ k ] )
if k not in averager [ ' metrics ' ] :
averager [ ' metrics ' ] [ k ] = [ data [ k ] ]
else :
averager [ ' metrics ' ] [ k ] . append ( data [ k ] )
unq [ f ' { it } _ { mode } _ { name } ' ] = averager
else :
@ -1685,6 +1700,8 @@ class TrainingState():
if args . tts_backend == " vall-e " :
stats = unq [ it ]
data = { k : sum ( v ) / len ( v ) for k , v in stats [ ' metrics ' ] . items ( ) if k not in blacklist }
#data = {k: min(v) for k, v in stats['metrics'].items() if k not in blacklist }
#data = {k: max(v) for k, v in stats['metrics'].items() if k not in blacklist }
data [ ' name ' ] = stats [ ' name ' ]
data [ ' mode ' ] = stats [ ' mode ' ]
data [ ' steps ' ] = len ( stats [ ' metrics ' ] [ ' it ' ] )