# todo: clean this mess up

from .config import cfg
from .data import create_train_val_dataloader
from .emb import qnt

from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .data import fold_inputs, unfold_outputs

import auraloss
import json
import logging
import random
import torch
import torch.nn.functional as F
import traceback

from collections import defaultdict

from tqdm import tqdm
import argparse

_logger = logging.getLogger(__name__)

# Used by run_eval() as a stand-in loss between decoded reference and hypothesis audio.
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")


def train_feeder(engine, batch):
    """Run one forward pass of ``engine`` on ``batch`` and collect its loss and stats.

    Args:
        engine: Callable engine exposing ``hyper_config``, ``gather_attribute``
            and a ``tokens_processed`` counter (and ``_cfg.prom_levels`` on the
            non-experimental path).
        batch: Mapping with the keys ``"text"``, ``"proms"`` and ``"resps"``;
            the non-experimental path also reads ``"lang"``.

    Returns:
        A ``(loss, stats)`` tuple: ``loss`` is the sum of every gathered loss
        value, ``stats`` maps each gathered loss/stat name to its ``.item()``.
    """
    with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
        if engine.hyper_config.experimental:
            batch_size = len(batch["text"])
            if cfg.model.interleave:
                quant_levels = None
                resps_list = list(batch["resps"])
            else:
                # Level 0 is only sampled when the model has the "ar" capability.
                lowest_level = 0 if "ar" in cfg.model.capabilities else 1
                quant_levels = torch.randint(lowest_level, cfg.model.max_levels, (batch_size,))
                # A sample drawn at level 0 is fed no prior response.
                resps_list = [
                    [] if level == 0 else resp
                    for level, resp in zip(quant_levels, batch["resps"])
                ]

            input_ids, _ = fold_inputs(
                text_list=batch["text"],
                prom_list=batch["proms"],
                resp_list=resps_list,
                targ_list=batch["resps"],
                quant_levels=quant_levels,
            )
            # Same folding again, but with ignore_index=-100 to build the labels.
            target_ids, _ = fold_inputs(
                text_list=batch["text"],
                prom_list=batch["proms"],
                resp_list=resps_list,
                targ_list=batch["resps"],
                quant_levels=quant_levels,
                ignore_index=-100,
            )

            engine(
                input_ids=input_ids,
                labels=target_ids,
            )
        else:
            engine(
                text_list=batch["text"],
                # reduce the input prompt to the target prom level
                proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]],
                resps_list=batch["resps"],
                lang_list=batch["lang"],
            )

        losses = engine.gather_attribute("loss")
        stat = engine.gather_attribute("stats")

        loss = torch.stack([*losses.values()]).sum()

    stats = {}
    stats |= {k: v.item() for k, v in losses.items()}
    stats |= {k: v.item() for k, v in stat.items()}

    engine.tokens_processed += sum(text.shape[0] for text in batch["text"])
    engine.tokens_processed += sum(resps.shape[0] for resps in batch["resps"])

    return loss, stats


def _generate_by_level(engine, batch):
    """Generate responses one quantizer level at a time (non-interleaved experimental path).

    Args:
        engine: Engine whose ``module`` exposes ``generate``.
        batch: Mapping with the keys ``"text"`` and ``"proms"``.

    Returns:
        One tensor per sample, built by stacking that sample's per-level
        responses and transposing the result.
    """
    steps = cfg.evaluation.steps
    batch_size = len(batch["text"])
    resps_list = [[] for _ in range(batch_size)]

    for level in range(cfg.model.max_levels):
        quant_levels = [[level] for _ in range(batch_size)]

        input_ids, attention_mask = fold_inputs(
            text_list=batch["text"],
            prom_list=batch["proms"],
            resp_list=resps_list,
            quant_levels=quant_levels,
            experimental=True,
        )
        # Require at least one token beyond the longest folded input.
        min_length = max((ids.shape[0] + 1 for ids in input_ids), default=1)

        output = engine.module.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            min_length=min_length,
            max_length=min_length + steps * (2 if level > 0 else 1),
            eos_token_id=3,
            do_sample=False,
        )
        unfolded = unfold_outputs(output, quant_levels=quant_levels)

        # After the first level, `steps` becomes the longest level-0 response.
        if level == 0:
            steps = 0

        for i, resp in enumerate(unfolded["resp_list"]):
            length = resp.shape[-1]

            if level == 0:
                # store length
                steps = max(steps, length)
            else:
                # pad / truncate to the level-0 length
                resp = resp[:steps]
                if length < steps:
                    resp = torch.cat([resp, torch.zeros(steps - length).to(resp)])

            resps_list[i].append(resp)

    # NOTE(review): level-0 responses are never padded, so samples shorter than
    # `steps` may make torch.stack fail here -- confirm against unfold_outputs.
    return [torch.stack(resp).t() for resp in resps_list]


@torch.inference_mode()
def run_eval(engines, eval_name, dl):
    """Generate audio for evaluation batches and log a mel-STFT pseudo loss.

    Batches are pulled from ``dl`` until at least ``cfg.evaluation.size``
    samples have been processed. Every engine generates responses for each
    batch; reference, hypothesis and prompt audio are decoded to ``.wav``
    files under ``cfg.log_dir / <global_step>``.

    Args:
        engines: Mapping-like collection of engines (iterated by name, indexed
            by name) exposing ``global_step``.
        eval_name: Label for this evaluation set; used in output paths and in
            the logged metric key.
        dl: Dataloader yielding batch mappings.
    """
    stats = defaultdict(list)
    stats['loss'] = []

    def process(name, batch, resps_list):
        """Decode one batch to audio files and record the per-sample pseudo loss."""
        samples = zip(
            batch["spkr_name"],
            batch["path"],
            batch["resps"],
            resps_list,
            batch["proms"],
            batch["task"],
        )
        for speaker, path, ref, hyp, prom, task in samples:
            # Nothing was generated for this sample.
            if len(hyp) == 0:
                continue

            filename = f'{speaker}_{path.parts[-1]}'

            if task != "tts":
                filename = f"{filename}_{task}"

            # to-do, refine the output dir to be sane-er
            step_dir = cfg.log_dir / str(engines.global_step)
            ref_path = (step_dir / "ref" / filename).with_suffix(".wav")
            hyp_path = (step_dir / name / eval_name / filename).with_suffix(".wav")
            prom_path = (step_dir / name / "prom" / filename).with_suffix(".wav")

            hyp_path.parent.mkdir(parents=True, exist_ok=True)
            ref_path.parent.mkdir(parents=True, exist_ok=True)
            prom_path.parent.mkdir(parents=True, exist_ok=True)

            ref_audio, _ = qnt.decode_to_file(ref, ref_path)
            hyp_audio, _ = qnt.decode_to_file(hyp, hyp_path)
            # The prompt is decoded only to write it to disk.
            qnt.decode_to_file(prom, prom_path)

            # pseudo loss calculation since we don't get the logits during eval
            min_length = min(ref_audio.shape[-1], hyp_audio.shape[-1])
            ref_audio = ref_audio[..., 0:min_length]
            hyp_audio = hyp_audio[..., 0:min_length]
            stats['loss'].append(
                mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item()
            )

    processed = 0
    while processed < cfg.evaluation.size:
        # NOTE(review): a fresh iterator is created on every pass, so whether
        # successive batches differ depends on how `dl` samples -- confirm.
        batch: dict = to_device(next(iter(dl)), cfg.device)
        processed += len(batch["text"])

        for name in engines:
            engine = engines[name]

            if engine.hyper_config.experimental:
                if cfg.model.interleave:
                    input_ids, attention_mask = fold_inputs(
                        text_list=batch["text"],
                        prom_list=batch["proms"],
                    )
                    output = engine.module.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        max_length=cfg.evaluation.steps,
                        eos_token_id=3,
                        do_sample=False,
                    )
                    resps_list = unfold_outputs(output)["resp_list"]
                else:
                    resps_list = _generate_by_level(engine, batch)
            else:
                if "ar" in engine.hyper_config.capabilities:
                    resps_list = engine(
                        text_list=batch["text"],
                        proms_list=batch["proms"],
                        lang_list=batch["lang"],
                        max_steps=cfg.evaluation.steps,
                        sampling_temperature=cfg.evaluation.ar_temperature,
                    )
                else:
                    # Without "ar", start from column 0 of each reference response.
                    resps_list = [resp[:, 0] for resp in batch["resps"]]

                if "nar" in engine.hyper_config.capabilities:
                    resps_list = engine(
                        text_list=batch["text"],
                        proms_list=batch["proms"],
                        lang_list=batch["lang"],
                        resps_list=resps_list,
                        sampling_temperature=cfg.evaluation.nar_temperature,
                    )

            process(name, batch, resps_list)

    # Skip metrics with no samples (e.g. every hypothesis was empty) instead of
    # dividing by zero.
    stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
    # NOTE(review): metrics from all engines are pooled together and reported
    # under the name of the last engine iterated.
    engines_stats = {
        f'{name}.{eval_name}': stats,
        "it": engines.global_step,
    }
    #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

    if cfg.trainer.no_logger:
        tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
    else:
        _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")


def train():
    """Entry point: parse CLI flags, build the dataloaders, then train or only evaluate.

    With ``--eval`` the engines are loaded and evaluated once on the subtrain
    and validation loaders; otherwise control is handed to ``trainer.train``.
    """
    parser = argparse.ArgumentParser("VALL-E TTS")
    parser.add_argument("--eval", action="store_true", default=None)
    # Unrecognised arguments are tolerated and ignored here.
    args, _ = parser.parse_known_args()

    setup_logging(cfg.log_dir)

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    def eval_fn(engines):
        """Evaluate on the subtrain and val loaders, then switch back to train mode."""
        do_gc()
        engines.eval()
        # wrapped in a try block because it's sometimes prone to breaking
        try:
            run_eval(engines, "subtrain", subtrain_dl)
            run_eval(engines, "val", val_dl)
        except Exception as e:
            print("Error occurred while performing eval:", str(e))
            print(traceback.format_exc())

        engines.train()
        qnt.unload_model()
        do_gc()

    qnt.unload_model()

    if args.eval:
        return eval_fn(engines=trainer.load_engines())

    # Disabled web UI hook:
    # if cfg.trainer.load_webui:
    #     from .webui import start
    #     start(lock=False)

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )


if __name__ == "__main__":
    # to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
    train()