# todo: clean this mess up

from .config import cfg
from .data import create_train_val_dataloader
from .emb import mel

from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc, wrapper as ml
from .utils.distributed import is_global_leader

import auraloss
import json
import logging
import random
import torch
import torchaudio
import torch.nn.functional as F
import traceback
import shutil

from collections import defaultdict

from tqdm import tqdm
import argparse

from torch.nn.utils.rnn import pad_sequence

from .models.arch_utils import denormalize_tacotron_mel
from .models.diffusion import get_diffuser
from .models import load_model

_logger = logging.getLogger(__name__)

mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")

def train_feeder(engine, batch):
    with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
        device = batch["text"][0].device
        batch_size = len(batch["text"])

        autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
        diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])

        text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first=True)
        text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
        mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first=True, padding_value=engine.module.stop_mel_token)
        wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)

        engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths)

        losses = engine.gather_attribute("loss")
        stat = engine.gather_attribute("stats")

        loss = torch.stack([*losses.values()]).sum()

    stats = {}
    stats |= {k: v.item() for k, v in losses.items()}
    stats |= {k: v.item() for k, v in stat.items()}

    engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
    engine.tokens_processed += sum([ mel.shape[-1] for mel in batch["mel"] ])

    return loss, stats

@torch.inference_mode()
def run_eval(engines, eval_name, dl):
    stats = defaultdict(list)
    stats['loss'] = []

    autoregressive = None
    diffusion = None
    clvp = None
    vocoder = None
    diffuser = get_diffuser(steps=30, cond_free=False)

    # prefer models already loaded into the training engines
    for name in engines:
        engine = engines[name]
        if "autoregressive" in name:
            autoregressive = engine.module
        elif "diffusion" in name:
            diffusion = engine.module
        elif "clvp" in name:
            clvp = engine.module
        elif "vocoder" in name:
            vocoder = engine.module

    # fall back to loading whatever the engines don't provide
    if autoregressive is None:
        autoregressive = load_model("autoregressive", device=cfg.device)
    if diffusion is None:
        diffusion = load_model("diffusion", device=cfg.device)
    if clvp is None:
        clvp = load_model("clvp", device=cfg.device)
    if vocoder is None:
        vocoder = load_model("vocoder", device=cfg.device)

    def generate( batch, generate_codes=True ):
        temperature = 1.0
        max_mel_tokens = 500 # * autoregressive.mel_length_compression
        stop_mel_token = autoregressive.stop_mel_token
        calm_token = 83 # the token coding silence
        verbose = False

        autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
        diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])

        text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first=True)
        text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
        mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first=True, padding_value=stop_mel_token)
        wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)

        mel_codes = autoregressive.set_mel_padding(mel_codes, wav_lengths)
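        # The stages below follow TorToiSe's usual inference path: the autoregressive
        # model samples discrete mel codes from text (or reuses the ground-truth codes
        # when generate_codes=False), the codes are cleaned up around the stop token,
        # the AR latents for those codes condition the diffusion model into a mel
        # spectrogram, and the vocoder renders that spectrogram into a 24kHz waveform.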
        with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
            # autoregressive pass: sample mel codes from the text, or reuse the ground-truth codes for the reference
            if generate_codes:
                codes = autoregressive.inference_speech(
                    autoregressive_latents,
                    text_tokens,
                    do_sample=True,
                    top_p=0.8,
                    temperature=temperature,
                    num_return_sequences=1,
                    length_penalty=1.0,
                    repetition_penalty=2.0,
                    max_generate_length=max_mel_tokens,
                )
                padding_needed = max_mel_tokens - codes.shape[1]
                codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
            else:
                codes = mel_codes

            # clean up the codes around the stop token (the magic numbers mirror TorToiSe's fix_autoregressive_output)
            for i, code in enumerate( codes ):
                stop_token_indices = (codes[i] == stop_mel_token).nonzero()
                if len(stop_token_indices) == 0:
                    continue
                codes[i][stop_token_indices] = 83
                stm = stop_token_indices.min().item()
                codes[i][stm:] = 83
                if stm - 3 < codes[i].shape[0]:
                    codes[i][-3] = 45
                    codes[i][-2] = 45
                    codes[i][-1] = 248

            wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)

            latents = autoregressive.forward(
                autoregressive_latents,
                text_tokens,
                text_lengths,
                codes,
                wav_lengths,
                return_latent=True,
                clip_inputs=False,
            )

            # trim trailing silence from the latents
            calm_tokens = 0
            for k in range( codes.shape[-1] ):
                if codes[0, k] == calm_token:
                    calm_tokens += 1
                else:
                    calm_tokens = 0
                if calm_tokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                    latents = latents[:, :k]
                    break

            # diffusion pass
            with ml.auto_unload(diffusion, enabled=True):
                output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
                output_shape = (latents.shape[0], 100, output_seq_len)
                precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)

                noise = torch.randn(output_shape, device=latents.device) * temperature
                mel = diffuser.p_sample_loop(
                    diffusion,
                    output_shape,
                    noise=noise,
                    model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
                    progress=True,
                )
                mels = denormalize_tacotron_mel(mel)[:, :, :output_seq_len]

            # vocoder pass
            with ml.auto_unload(vocoder, enabled=True):
                wavs = vocoder.inference(mels)

        return wavs

    def process( name, batch, hyps, refs ):
        for speaker, path, ref_audio, hyp_audio in zip(batch["spkr_name"], batch["path"], refs, hyps):
            filename = f'{speaker}_{path.parts[-1]}'

            # to-do: refine the output dir to be saner
            ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
            hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
            prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")

            hyp_path.parent.mkdir(parents=True, exist_ok=True)
            ref_path.parent.mkdir(parents=True, exist_ok=True)
            prom_path.parent.mkdir(parents=True, exist_ok=True)

            torchaudio.save( hyp_path, hyp_audio.cpu(), 24_000 )
            torchaudio.save( ref_path, ref_audio.cpu(), 24_000 )

            # pseudo loss calculation since we don't get the logits during eval
            min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
            ref_audio = ref_audio[..., 0:min_length]
            hyp_audio = hyp_audio[..., 0:min_length]
            stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())

    processed = 0
    iterator = iter(dl) # create the iterator once, rather than re-creating it (and redrawing the first batch) every pass
    while processed < cfg.evaluation.size:
        batch = to_device(next(iterator), cfg.device)

        batch_size = len(batch["text"])
        processed += batch_size

        hyp = generate( batch, generate_codes=True )
        ref = generate( batch, generate_codes=False )

        process( name, batch, hyp, ref )

    stats = {k: sum(v) / len(v) for k, v in stats.items()}
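    # Note: the "loss" averaged above is the auraloss mel-STFT distance between the
    # vocoded hypothesis and reference audio collected in process(); it is a proxy
    # metric for eval, not the model's training loss.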
    engines_stats = {
        f'{name}.{eval_name}': stats,
        "it": engines.global_step,
    }
    #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

    if cfg.trainer.no_logger:
        tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
    else:
        _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")

    diffusion = diffusion.to("cpu")
    clvp = clvp.to("cpu")
    vocoder = vocoder.to("cpu")

def train():
    parser = argparse.ArgumentParser("TorToiSe TTS")
    parser.add_argument("--eval", action="store_true", default=None)
    args, unknown = parser.parse_known_args()

    # create the log folder
    setup_logging(cfg.log_dir)

    # back up the config YAML
    if cfg.yaml_path is not None and is_global_leader():
        shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    def eval_fn(engines):
        do_gc()
        engines.eval()

        # wrapped in a try block because it's sometimes prone to breaking
        try:
            run_eval(engines, "subtrain", subtrain_dl)
            run_eval(engines, "val", val_dl)
        except Exception as e:
            _logger.warning(f"Error occurred while performing eval: {str(e)}")
            _logger.warning(traceback.format_exc())

        engines.train()
        #qnt.unload_model()
        do_gc()

    if args.eval:
        return eval_fn(engines=trainer.load_engines())

    """
    if cfg.trainer.load_webui:
        from .webui import start
        start(lock=False)
    """

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )

if __name__ == "__main__":
    # to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
    train()
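# Example invocations (the torchrun form is taken from the to-do above; the
# single-process form is an assumption based on the same module path):
#   python -m vall_e.train yaml="./data/config.yaml"
#   torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"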