# tortoise-tts/tortoise_tts/train.py
# todo: clean this mess up
import argparse
import json
import logging
import random
import shutil
import traceback

from collections import defaultdict

import auraloss
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

from .config import cfg
from .data import create_train_val_dataloader
from .emb import mel
from .models.arch_utils import denormalize_tacotron_mel
from .models.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .utils.distributed import is_global_leader

_logger = logging.getLogger(__name__)

mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
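
# train_feeder is the per-step hook handed to trainer.train(): it unpacks a
# collated batch, runs the wrapped engine's forward pass under autocast, and
# returns the summed loss alongside scalar stats for logging.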
def train_feeder(engine, batch):
    with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
        device = batch["text"][0].device
        batch_size = len(batch["text"])

        # stack the precomputed conditioning embeddings and latents for the
        # autoregressive and diffusion models
        autoregressive_conds = torch.stack([ conds for conds in batch["conds_0"] ])
        diffusion_conds = torch.stack([ conds for conds in batch["conds_1"] ])
        autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
        diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])

        # pad variable-length text and mel-code sequences to a common length
        text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first=True)
        text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
        mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first=True, padding_value=engine.module.stop_mel_token)
        wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)

        engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths)

        losses = engine.gather_attribute("loss")
        stat = engine.gather_attribute("stats")

        loss = torch.stack([*losses.values()]).sum()

    stats = {}
    stats |= {k: v.item() for k, v in losses.items()}
    stats |= {k: v.item() for k, v in stat.items()}

    engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
    engine.tokens_processed += sum([ mel.shape[-1] for mel in batch["mel"] ])

    return loss, stats
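
# run_eval walks the full three-stage TorToiSe inference path on held-out
# batches: the autoregressive model samples mel codes from text, the diffusion
# model turns the resulting latents into a mel spectrogram, and the vocoder
# renders the waveform. Only a pseudo-loss is available here, since logits
# aren't returned during inference.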
@torch.inference_mode()
def run_eval(engines, eval_name, dl):
    stats = defaultdict(list)
    stats['loss'] = []

    # Note: this helper is currently unused (its call near the bottom of the
    # eval loop is commented out) and references an `emb` module that is not
    # imported here.
    def process( name, batch, resps_list ):
        for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
            if len(hyp) == 0:
                continue

            filename = f'{speaker}_{path.parts[-1]}'
            if task != "tts":
                filename = f"{filename}_{task}"

            # to-do: refine the output dir to be saner
            ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
            hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
            prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")

            hyp_path.parent.mkdir(parents=True, exist_ok=True)
            ref_path.parent.mkdir(parents=True, exist_ok=True)
            prom_path.parent.mkdir(parents=True, exist_ok=True)

            ref_audio, sr = emb.decode_to_file(ref, ref_path)
            hyp_audio, sr = emb.decode_to_file(hyp, hyp_path)
            prom_audio, sr = emb.decode_to_file(prom, prom_path)

            # pseudo loss calculation, since we don't get the logits during
            # eval: compare reference and hypothesis audio over their
            # overlapping span
            min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
            ref_audio = ref_audio[..., 0:min_length]
            hyp_audio = hyp_audio[..., 0:min_length]
            stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())

    # pick the individual models out of the engine collection by name
    autoregressive = None
    diffusion = None
    clvp = None
    vocoder = None

    for name in engines:
        engine = engines[name]
        if "autoregressive" in name:
            autoregressive = engine.module
        elif "diffusion" in name:
            diffusion = engine.module
        elif "clvp" in name:
            clvp = engine.module
        elif "vocoder" in name:
            vocoder = engine.module

    trained_diffusion_steps = 4000
    desired_diffusion_steps = 50
    cond_free = False
    cond_free_k = 1

    diffuser = SpacedDiffusion(
        use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
        model_mean_type='epsilon',
        model_var_type='learned_range',
        loss_type='mse',
        betas=get_named_beta_schedule('linear', trained_diffusion_steps),
        conditioning_free=cond_free,
        conditioning_free_k=cond_free_k
    )
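    # The spaced diffuser resamples the 4000-step training schedule down to 50
    # sampling steps. Conditioning-free guidance is disabled here; upstream
    # TorToiSe typically enables it at inference time.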

    processed = 0
    temperature = 1.0

    while processed < cfg.evaluation.size:
        batch: dict = to_device(next(iter(dl)), cfg.device)
        processed += len(batch["text"])

        max_mel_tokens = 500
        stop_mel_token = autoregressive.stop_mel_token
        calm_token = 83
        verbose = True

        with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
            # batch preparation mirrors train_feeder()
            autoregressive_conds = torch.stack([ conds for conds in batch["conds_0"] ])
            diffusion_conds = torch.stack([ conds for conds in batch["conds_1"] ])
            autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
            diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])

            text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first=True)
            text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
            mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first=True, padding_value=stop_mel_token)
            wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)

            # autoregressive pass: sample mel codes from the text, then pad the
            # generated sequence out to max_mel_tokens with the stop token
            if True: # flip to False to instead teacher-force the ground-truth mel codes
                codes = autoregressive.inference_speech(
                    autoregressive_latents,
                    text_tokens,
                    do_sample=True,
                    #top_p=top_p,
                    temperature=temperature,
                    num_return_sequences=1,
                    #length_penalty=length_penalty,
                    #repetition_penalty=repetition_penalty,
                    max_generate_length=max_mel_tokens,
                )
                padding_needed = max_mel_tokens - codes.shape[1]
                codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
            else:
                codes = mel_codes

            # re-run the autoregressive model in forward mode to recover the
            # latents the diffusion model conditions on
            latents = autoregressive.forward(
                autoregressive_latents,
                text_tokens,
                text_lengths,
                codes,
                wav_lengths,
                return_latent=True,
                clip_inputs=False
            )

            # trim the latents after a run of "calm" tokens; 8 tokens gives the
            # diffusion model some "breathing room" to terminate speech
            calm_tokens = 0
            for k in range( codes.shape[-1] ):
                if codes[0, k] == calm_token:
                    calm_tokens += 1
                else:
                    calm_tokens = 0
                if calm_tokens > 8:
                    latents = latents[:, :k]
                    break

            # diffusion pass: this diffusion model converts 22.05kHz spectrogram
            # codes into a 24kHz spectrogram signal
            output_seq_len = latents.shape[1] * 4 * 24000 // 22050
            output_shape = (latents.shape[0], 100, output_seq_len)
            precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)

            noise = torch.randn(output_shape, device=latents.device) * temperature
            mel = diffuser.p_sample_loop(
                diffusion,
                output_shape,
                noise=noise,
                model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
                progress=verbose
            )
            mels = denormalize_tacotron_mel(mel)[:, :, :output_seq_len]

            # vocoder pass: mel spectrogram to waveform
            wavs = vocoder.inference(mels)

            for i, wav in enumerate( wavs ):
                torchaudio.save( f"./data/{cfg.start_time}[{i}].wav", wav.cpu(), 24_000 )

            # process( name, batch, resps_list )

    # average each collected metric; skip empty lists to avoid dividing by
    # zero (the mel-STFT pseudo-loss is only populated by the currently
    # disabled process() helper)
    stats = {k: sum(v) / len(v) for k, v in stats.items() if len(v) > 0}

    engines_stats = {
        f'{name}.{eval_name}': stats,
        "it": engines.global_step,
    }
    #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

    if cfg.trainer.no_logger:
        tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
    else:
        _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")

def train():
    parser = argparse.ArgumentParser("TorToiSe TTS")
    parser.add_argument("--eval", action="store_true", default=None)
    args, unknown = parser.parse_known_args()

    # create the log folder
    setup_logging(cfg.log_dir)

    # back up the config yaml
    if cfg.yaml_path is not None and is_global_leader():
        shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    def eval_fn(engines):
        do_gc()
        engines.eval()

        # wrapped in a try block because eval is sometimes prone to breaking
        try:
            run_eval(engines, "subtrain", subtrain_dl)
            run_eval(engines, "val", val_dl)
        except Exception as e:
            _logger.warning(f"Error occurred while performing eval: {str(e)}")
            _logger.warning(traceback.format_exc())

        engines.train()
        #qnt.unload_model()
        do_gc()

    if args.eval:
        return eval_fn(engines=trainer.load_engines())

    """
    if cfg.trainer.load_webui:
        from .webui import start
        start(lock=False)
    """

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )

if __name__ == "__main__":
    # to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
    train()
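
# Example invocation (module path assumed from this file's location):
#   python -m tortoise_tts.train yaml="./data/config.yaml"
# or, for multi-GPU DDP:
#   torchrun --nnodes=1 --nproc-per-node=4 -m tortoise_tts.train yaml="./data/config.yaml"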