From 853c7fdccf13277e500307aa241e149d7e99ac7d Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 3 May 2023 21:31:37 +0000 Subject: [PATCH] forgot to uncomment the block to transcribe and slice when using transcribe all because I was piece-processing a huge batch of LibriTTS and somehow that leaked over to the repo --- src/utils.py | 19 ++++++++++++++++++- src/webui.py | 4 +--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/utils.py b/src/utils.py index 4eb4b85..8e5ea00 100755 --- a/src/utils.py +++ b/src/utils.py @@ -1458,6 +1458,7 @@ class TrainingState(): 'lrs': ['lr'], 'losses': ['loss_text_ce', 'loss_mel_ce'], 'accuracies': [], + 'precisions': [], 'grad_norms': [], } if args.tts_backend == "vall-e": @@ -1481,6 +1482,11 @@ class TrainingState(): 'ar-half.loss.acc', 'nar-half.loss.acc', 'ar-quarter.loss.acc', 'nar-quarter.loss.acc', ] + keys['precisions'] = [ + 'ar.loss.precision', 'nar.loss.precision', + 'ar-half.loss.precision', 'nar-half.loss.precision', + 'ar-quarter.loss.precision', 'nar-quarter.loss.precision', + ] keys['grad_norms'] = ['ar.grad_norm', 'nar.grad_norm', 'ar-half.grad_norm', 'nar-half.grad_norm', 'ar-quarter.grad_norm', 'nar-quarter.grad_norm'] for k in keys['lrs']: @@ -1494,6 +1500,12 @@ class TrainingState(): continue self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k}) + + for k in keys['precisions']: + if k not in self.info: + continue + + self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k}) for k in keys['losses']: if k not in self.info: @@ -1671,7 +1683,10 @@ class TrainingState(): for k in data: if data[k] is None: continue - averager['metrics'][k].append( data[k] ) + if k not in averager['metrics']: + averager['metrics'][k] = [ data[k] ] + else: + averager['metrics'][k].append( data[k] ) unq[f'{it}_{mode}_{name}'] = averager else: @@ -1685,6 +1700,8 @@ class TrainingState(): if args.tts_backend == "vall-e": stats = unq[it] data = {k: sum(v) / len(v) for k, v in stats['metrics'].items() if k not in blacklist } + #data = {k: min(v) for k, v in stats['metrics'].items() if k not in blacklist } + #data = {k: max(v) for k, v in stats['metrics'].items() if k not in blacklist } data['name'] = stats['name'] data['mode'] = stats['mode'] data['steps'] = len(stats['metrics']['it']) diff --git a/src/webui.py b/src/webui.py index f2a5ba6..e079675 100755 --- a/src/webui.py +++ b/src/webui.py @@ -221,7 +221,6 @@ def prepare_all_datasets( language, validation_text_length, validation_audio_len messages = [] voices = get_voice_list() - """ for voice in voices: print("Processing:", voice) message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress ) @@ -232,8 +231,7 @@ def prepare_all_datasets( language, validation_text_length, validation_audio_len print("Processing:", voice) message = slice_dataset( voice, trim_silence=trim_silence, start_offset=slice_start_offset, end_offset=slice_end_offset, results=None, progress=progress ) messages.append(message) - """ - + for voice in voices: print("Processing:", voice) message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length, progress=progress )