From 508677fcd531f6738835f0065bc989dc2257b82e Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 18 Aug 2023 21:19:47 -0500
Subject: [PATCH] repaired auraloss loss calc during eval/val

---
 data/config.yaml | 46 ++++++++++++++++++++++++----------------------
 vall_e/train.py  | 23 ++++++-----------------
 2 files changed, 30 insertions(+), 39 deletions(-)

diff --git a/data/config.yaml b/data/config.yaml
index 152ae61..1055cbc 100755
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -1,49 +1,49 @@
 dataset:
-  training: [
-  ]
+  training: []
 
-  validation: [
-  ]
+  validation: []
 
   speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
 
   use_hdf5: True
+  hdf5_flag: r
   validate: True
 
-  workers: 8
+  workers: 4
   cache: True
 
-  phones_range: [4, 128]
-  duration_range: [1.0, 12.0]
+  phones_range: [4, 256]
+  duration_range: [1.0, 16.0]
 
   random_utterance: 1.0
   max_prompts: 3
   prompt_duration: 3.0
 
+  sample_type: speaker
+
 models:
   _models:
   - name: "ar"
     size: "full"
     resp_levels: 1
-    arch_type: "retnet"
     prom_levels: 2
     tasks: 8
+    arch_type: "retnet"
 
   - name: "nar"
     size: "full"
     resp_levels: 1
-    arch_type: "retnet"
     prom_levels: 2
     tasks: 8
-
+    arch_type: "retnet"
 
 hyperparameters:
-  batch_size: 8
-  gradient_accumulation_steps: 16
+  batch_size: 16
+  gradient_accumulation_steps: 2
   gradient_clipping: 100
 
   optimizer: Adamw
-  learning_rate: 1.0e-4
+  learning_rate: 1.0e-5
 
   scheduler_type: ""
   #scheduler_type: OneCycle
@@ -66,40 +66,42 @@ hyperparameters:
 
 evaluation:
   batch_size: 32
-  frequency: 250
+  frequency: 500
   size: 32
   steps: 300
   ar_temperature: 1.0
-  nar_temperature: 0.2
+  nar_temperature: 1.0
 
 trainer:
-  iterations: 100_000
+  iterations: 1_000_000
 
   save_tag: step
   save_on_oom: True
   save_on_quit: True
-  save_frequency: 100
+  save_frequency: 1000
+
+  keep_last_checkpoints: 4
 
   aggressive_optimizations: False
 
-  #load_tag: "9500"
   #load_state_dict: True
-  #load_states: False
   #strict_loading: False
+  #load_tag: "9500"
+  #load_states: False
   #restart_step_count: True
 
   gc_mode: None # "global_step"
 
-  weight_dtype: bfloat16 # float16, float32
+  weight_dtype: bfloat16
 
   backend: deepspeed
 
   deepspeed:
-    zero_optimization_level: 2
+    zero_optimization_level: 0
     use_compression_training: True
 
 inference:
   use_vocos: True
 
 bitsandbytes:
-  enabled: false
\ No newline at end of file
+  enabled: false
diff --git a/vall_e/train.py b/vall_e/train.py
index 2e984c7..f59c4ec 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -20,15 +20,6 @@ from collections import defaultdict
 from tqdm import tqdm
 
 mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
-
-def center_crop(x, len):
-	start = (x.shape[-1] - len) // 2
-	stop = start + len
-	return x[..., start:stop]
-
-def left_crop(x, len):
-	return x[..., 0:len]
-
 _logger = logging.getLogger(__name__)
 
 def train_feeder(engine, batch):
@@ -87,17 +78,18 @@ def run_eval(engines, eval_name, dl):
 
 		# pseudo loss calculation since we don't get the logits during eval
 		min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
-		ref_audio = center_crop(ref_audio, min_length) #ref_audio[..., 0:min_length]
-		hyp_audio = center_crop(hyp_audio, min_length) #hyp_audio[..., 0:min_length]
+		ref_audio = ref_audio[..., 0:min_length]
+		hyp_audio = hyp_audio[..., 0:min_length]
 		try:
-			stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
+			stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
 		except Exception as e:
 			stats['loss'].append(0)
 			print(traceback.format_exc())
 
 	processed = 0
-	for batch in tqdm(dl):
-		batch: dict = to_device(batch, cfg.device)
+	while processed < cfg.evaluation.size:
+		batch: dict = to_device(next(iter(dl)), cfg.device)
+		processed += len(batch["text"])
 
 		# if we're training both models, provide output for both
 		if AR is not None and NAR is not None:
@@ -132,9 +124,6 @@ def run_eval(engines, eval_name, dl):
 
 			process( name, batch, resps_list )
 
-			processed += len(batch["text"])
-			if processed >= cfg.evaluation.size:
-				break
 
 	stats = {k: sum(v) / len(v) for k, v in stats.items()}
 	engines_stats.update(flatten_dict({ name: stats }))
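
Context for the train.py hunk: the eval/val loss was falling through to the `except` branch (appending 0) and the repair trims both decodes to the shorter length and adds a batch axis before calling the auraloss metric. A minimal standalone sketch of the repaired call, assuming the decoded reference/hypothesis audio arrives as (channels, samples) torch tensors; the random tensors are stand-ins, not part of the patch:

import torch
import auraloss

# same construction as in vall_e/train.py
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")

ref_audio = torch.randn(1, 24_000)   # stand-in for the decoded reference utterance
hyp_audio = torch.randn(1, 22_000)   # stand-in for the decoded model output

# trim both signals to the shorter length so the STFT frames line up
min_length = min(ref_audio.shape[-1], hyp_audio.shape[-1])
ref_audio = ref_audio[..., 0:min_length]
hyp_audio = hyp_audio[..., 0:min_length]

# the loss expects batched (batch, channels, time) input, hence the added leading axis
loss = mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :])
print(loss.item())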