repaired auraloss loss calc during eval/val

mrq 2023-08-18 21:19:47 -05:00
parent fb4e816823
commit 508677fcd5
2 changed files with 30 additions and 39 deletions
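The gist of the fix: auraloss's STFT losses take (batch, channels, time) tensors, while the decoded eval audio here is (channels, time), so the loss call needs both clips cropped to a common length and a leading batch dimension prepended. A minimal sketch of the repaired pseudo-loss calculation, assuming 24 kHz (channels, time) waveforms; the pseudo_loss wrapper is hypothetical, but the slicing and [None, :, :] mirror the diff below:

import torch
import auraloss

# mel-scaled multi-resolution STFT loss; auraloss expects (batch, channels, time)
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")

def pseudo_loss(hyp_audio: torch.Tensor, ref_audio: torch.Tensor) -> float:
	# crop both (channels, time) clips to the shorter one so the frames align
	min_length = min(ref_audio.shape[-1], hyp_audio.shape[-1])
	ref_audio = ref_audio[..., :min_length]
	hyp_audio = hyp_audio[..., :min_length]
	# prepend the batch dimension the loss expects: (1, channels, time)
	return mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item()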

View File

@@ -1,49 +1,49 @@
 dataset:
-  training: [
-  ]
-  validation: [
-  ]
+  training: []
+  validation: []
   speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
   use_hdf5: True
+  hdf5_flag: r
   validate: True
-  workers: 8
+  workers: 4
   cache: True
-  phones_range: [4, 128]
-  duration_range: [1.0, 12.0]
+  phones_range: [4, 256]
+  duration_range: [1.0, 16.0]
   random_utterance: 1.0
   max_prompts: 3
   prompt_duration: 3.0
+  sample_type: speaker
 models:
   _models:
   - name: "ar"
     size: "full"
     resp_levels: 1
-    arch_type: "retnet"
     prom_levels: 2
     tasks: 8
+    arch_type: "retnet"
   - name: "nar"
     size: "full"
     resp_levels: 1
-    arch_type: "retnet"
     prom_levels: 2
     tasks: 8
+    arch_type: "retnet"
 hyperparameters:
-  batch_size: 8
-  gradient_accumulation_steps: 16
+  batch_size: 16
+  gradient_accumulation_steps: 2
   gradient_clipping: 100
   optimizer: Adamw
-  learning_rate: 1.0e-4
+  learning_rate: 1.0e-5
   scheduler_type: ""
   #scheduler_type: OneCycle
@@ -66,40 +66,42 @@ hyperparameters:
 evaluation:
   batch_size: 32
-  frequency: 250
+  frequency: 500
   size: 32
   steps: 300
   ar_temperature: 1.0
-  nar_temperature: 0.2
+  nar_temperature: 1.0
 trainer:
-  iterations: 100_000
+  iterations: 1_000_000
   save_tag: step
   save_on_oom: True
   save_on_quit: True
-  save_frequency: 100
+  save_frequency: 1000
+  keep_last_checkpoints: 4
   aggressive_optimizations: False
-  #load_tag: "9500"
   #load_state_dict: True
-  #load_states: False
   #strict_loading: False
+  #load_tag: "9500"
+  #load_states: False
   #restart_step_count: True
   gc_mode: None # "global_step"
-  weight_dtype: bfloat16 # float16, float32
+  weight_dtype: bfloat16
   backend: deepspeed
   deepspeed:
-    zero_optimization_level: 2
+    zero_optimization_level: 0
     use_compression_training: True
 inference:
   use_vocos: True
 bitsandbytes:
   enabled: false
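For reference, the effective batch size is batch_size × gradient_accumulation_steps: the old settings average gradients over 8 × 16 = 128 samples per optimizer step, the new ones over 16 × 2 = 32.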

View File

@@ -20,15 +20,6 @@ from collections import defaultdict
 from tqdm import tqdm

 mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")

-def center_crop(x, len):
-	start = (x.shape[-1] - len) // 2
-	stop = start + len
-	return x[..., start:stop]
-
-def left_crop(x, len):
-	return x[..., 0:len]
-
 _logger = logging.getLogger(__name__)

 def train_feeder(engine, batch):
@@ -87,17 +78,18 @@ def run_eval(engines, eval_name, dl):
 		# pseudo loss calculation since we don't get the logits during eval
 		min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
-		ref_audio = center_crop(ref_audio, min_length) #ref_audio[..., 0:min_length]
-		hyp_audio = center_crop(hyp_audio, min_length) #hyp_audio[..., 0:min_length]
+		ref_audio = ref_audio[..., 0:min_length]
+		hyp_audio = hyp_audio[..., 0:min_length]

 		try:
-			stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
+			stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
 		except Exception as e:
 			stats['loss'].append(0)
 			print(traceback.format_exc())

 	processed = 0
-	for batch in tqdm(dl):
-		batch: dict = to_device(batch, cfg.device)
+	while processed < cfg.evaluation.size:
+		batch: dict = to_device(next(iter(dl)), cfg.device)
+		processed += len(batch["text"])

 		# if we're training both models, provide output for both
 		if AR is not None and NAR is not None:
@@ -132,9 +124,6 @@ def run_eval(engines, eval_name, dl):
 			process( name, batch, resps_list )

-		processed += len(batch["text"])
-		if processed >= cfg.evaluation.size:
-			break

 	stats = {k: sum(v) / len(v) for k, v in stats.items()}
 	engines_stats.update(flatten_dict({ name: stats }))
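The loop rewrite above swaps a full dataloader sweep (breaking once enough samples were seen) for drawing batches until cfg.evaluation.size utterances have been processed. A sketch of the new pattern, with stand-in names (run_eval_batches, eval_size); note that next(iter(dl)) builds a fresh iterator on every pass, so with a shuffling sampler each draw is an independent random batch:

from typing import Iterable

def run_eval_batches(dl: Iterable[dict], eval_size: int) -> None:
	# draw batches until eval_size utterances have been scored
	processed = 0
	while processed < eval_size:
		# next(iter(dl)) re-creates the iterator each pass instead of
		# resuming a sweep; a shuffling sampler makes each draw random
		batch = next(iter(dl))
		processed += len(batch["text"])
		# ... decode, generate, and accumulate the pseudo loss here ...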