repaired auraloss loss calc during eval/val

mrq 2023-08-18 21:19:47 -05:00
parent fb4e816823
commit 508677fcd5
2 changed files with 30 additions and 39 deletions

View File

@@ -1,49 +1,49 @@
 dataset:
-  training: [
-  ]
+  training: []
-  validation: [
-  ]
+  validation: []
   speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
   use_hdf5: True
   hdf5_flag: r
   validate: True
-  workers: 8
+  workers: 4
   cache: True
-  phones_range: [4, 128]
-  duration_range: [1.0, 12.0]
+  phones_range: [4, 256]
+  duration_range: [1.0, 16.0]
   random_utterance: 1.0
   max_prompts: 3
   prompt_duration: 3.0
   sample_type: speaker
-models:
+_models:
 - name: "ar"
   size: "full"
   resp_levels: 1
-  arch_type: "retnet"
+  prom_levels: 2
+  tasks: 8
+  arch_type: "retnet"
 - name: "nar"
   size: "full"
   resp_levels: 1
-  arch_type: "retnet"
+  prom_levels: 2
+  tasks: 8
+  arch_type: "retnet"
 hyperparameters:
-  batch_size: 8
-  gradient_accumulation_steps: 16
+  batch_size: 16
+  gradient_accumulation_steps: 2
   gradient_clipping: 100
   optimizer: Adamw
-  learning_rate: 1.0e-4
+  learning_rate: 1.0e-5
   scheduler_type: ""
   #scheduler_type: OneCycle
@@ -66,40 +66,42 @@ hyperparameters:
 evaluation:
   batch_size: 32
-  frequency: 250
+  frequency: 500
   size: 32
   steps: 300
   ar_temperature: 1.0
-  nar_temperature: 0.2
+  nar_temperature: 1.0
 trainer:
-  iterations: 100_000
+  iterations: 1_000_000
   save_tag: step
   save_on_oom: True
   save_on_quit: True
-  save_frequency: 100
+  save_frequency: 1000
   keep_last_checkpoints: 4
   aggressive_optimizations: False
-  #load_tag: "9500"
-  #load_state_dict: True
-  #load_states: False
-  #strict_loading: False
+  #load_tag: "9500"
+  #load_states: False
+  #restart_step_count: True
   gc_mode: None # "global_step"
-  weight_dtype: bfloat16 # float16, float32
+  weight_dtype: bfloat16
   backend: deepspeed
   deepspeed:
-    zero_optimization_level: 2
+    zero_optimization_level: 0
     use_compression_training: True
 inference:
   use_vocos: True
 bitsandbytes:
-  enabled: false
+  enabled: false
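
A quick check on the hyperparameter change above: batch_size and gradient_accumulation_steps move in opposite directions, so the effective batch size per optimizer step drops from 128 to 32. A minimal sketch of that arithmetic (plain Python, names are illustrative):

# effective batch = samples per forward pass * accumulation steps
old_effective = 8 * 16   # previous config: batch_size 8, gradient_accumulation_steps 16
new_effective = 16 * 2   # updated config:  batch_size 16, gradient_accumulation_steps 2
print(old_effective, new_effective)  # 128 32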

View File

@@ -20,15 +20,6 @@ from collections import defaultdict
 from tqdm import tqdm
 
-mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
-
-def center_crop(x, len):
-    start = (x.shape[-1] - len) // 2
-    stop = start + len
-    return x[..., start:stop]
-
-def left_crop(x, len):
-    return x[..., 0:len]
 
 _logger = logging.getLogger(__name__)
 
 def train_feeder(engine, batch):
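
For reference, the deleted center_crop helper took its window from the middle of the waveform, while the evaluation path in the next hunk now simply keeps the leading min_length samples of both signals. A standalone sketch contrasting the two cropping strategies (the example tensor is illustrative, not from the repo):

import torch

def center_crop(x: torch.Tensor, length: int) -> torch.Tensor:
    # the removed helper: take `length` samples from the middle of the last axis
    start = (x.shape[-1] - length) // 2
    return x[..., start:start + length]

def left_crop(x: torch.Tensor, length: int) -> torch.Tensor:
    # what the eval code now does inline: keep the first `length` samples
    return x[..., :length]

wav = torch.randn(1, 24_000)            # 1 channel, one second at 24 kHz
print(center_crop(wav, 8_000).shape)    # torch.Size([1, 8000])
print(left_crop(wav, 8_000).shape)      # torch.Size([1, 8000])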
@@ -87,17 +78,18 @@ def run_eval(engines, eval_name, dl):
         # pseudo loss calculation since we don't get the logits during eval
         min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
-        ref_audio = center_crop(ref_audio, min_length) #ref_audio[..., 0:min_length]
-        hyp_audio = center_crop(hyp_audio, min_length) #hyp_audio[..., 0:min_length]
+        ref_audio = ref_audio[..., 0:min_length]
+        hyp_audio = hyp_audio[..., 0:min_length]
 
         try:
-            stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
+            stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
         except Exception as e:
             stats['loss'].append(0)
             print(traceback.format_exc())
 
     processed = 0
-    for batch in tqdm(dl):
-        batch: dict = to_device(batch, cfg.device)
+    while processed < cfg.evaluation.size:
+        batch: dict = to_device(next(iter(dl)), cfg.device)
+        processed += len(batch["text"])
 
         # if we're training both models, provide output for both
         if AR is not None and NAR is not None:
@@ -132,9 +124,6 @@ def run_eval(engines, eval_name, dl):
             process( name, batch, resps_list )
 
-        processed += len(batch["text"])
-        if processed >= cfg.evaluation.size:
-            break
 
     stats = {k: sum(v) / len(v) for k, v in stats.items()}
     engines_stats.update(flatten_dict({ name: stats }))
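
The core of the fix, as the second hunk shows: the reference and generated waveforms are left-cropped to a common length and given a leading batch dimension before being handed to auraloss's MelSTFTLoss. A self-contained sketch of that pseudo-loss calculation, assuming decoded audio tensors of shape (channels, samples); the function name is illustrative:

import torch
import auraloss

# same construction as in the diff above
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")

def pseudo_loss(hyp_audio: torch.Tensor, ref_audio: torch.Tensor) -> float:
    # trim both signals to the shorter one so the spectrograms line up
    min_length = min(ref_audio.shape[-1], hyp_audio.shape[-1])
    ref_audio = ref_audio[..., :min_length]
    hyp_audio = hyp_audio[..., :min_length]
    # (channels, samples) -> (1, channels, samples): the loss expects a batch axis,
    # which is what the [None, :, :] unsqueeze in the commit adds
    return mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item()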