vall-e/vall_e/train.py

248 lines
8.0 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
# todo: clean this mess up
from .config import cfg
from .data import create_train_val_dataloader, get_random_prompt, tokenize
from .emb import qnt, g2p
2023-08-02 21:53:35 +00:00
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
2024-06-04 02:28:49 +00:00
from .data import fold_inputs, unfold_outputs
2024-06-09 16:39:43 +00:00
from .utils.distributed import is_global_leader
2023-08-02 21:53:35 +00:00
import auraloss
import json
import logging
import random
import torch
import torch.nn.functional as F
import traceback
import shutil
2023-08-02 21:53:35 +00:00
from collections import defaultdict
from tqdm import tqdm
import argparse
2023-08-02 21:53:35 +00:00
# Logger for this training module.
_logger = logging.getLogger(name=__name__)

# Mel-spectrogram STFT loss, built once at import time for the configured sample
# rate and placed on the CPU. run_eval uses it as a naive hypothesis-vs-reference
# audio metric.
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
2023-08-04 01:26:36 +00:00
def train_feeder(engine, batch):
    """Run one training forward pass on ``engine`` and collect its loss/stats.

    The engine is called with the batch's text, prompt, response, language,
    tone and task lists (``training=True``) under ``torch.autocast("cuda")``,
    using the dtype and enable flag from ``cfg.trainer``. Afterwards the
    engine's gathered ``loss`` and ``stats`` attributes are read back.

    Args:
        engine: the engine being trained. It is expected to expose
            ``gather_attribute`` and the ``current_batch_size`` /
            ``tokens_processed`` attributes written here.
        batch: collated batch dict with at least the keys ``"text"``,
            ``"proms"``, ``"resps"``, ``"lang"``, ``"tone"`` and ``"task"``.

    Returns:
        ``(loss, stats)``: ``loss`` is the sum of all gathered loss tensors, and
        ``stats`` maps every loss/stat name to a Python scalar (stat entries
        override loss entries on a name clash).
        Returns ``None`` when the summed loss contains NaN. Presumably the
        caller treats this as "skip this step" — confirm against the trainer.
    """
    mixed_precision = torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
    with mixed_precision:
        engine.current_batch_size = len(batch["text"])
        engine(
            text_list=batch["text"],
            proms_list=batch["proms"],
            resps_list=batch["resps"],
            lang_list=batch["lang"],
            tone_list=batch["tone"],
            task_list=batch["task"],
            training=True,
        )

    loss_terms = engine.gather_attribute("loss")
    extra_stats = engine.gather_attribute("stats")

    total_loss = torch.stack(list(loss_terms.values())).sum()

    # A NaN loss produces no result at all, so no tokens are counted for it.
    if torch.isnan(total_loss).any():
        return None

    stats = {key: tensor.item() for key, tensor in loss_terms.items()}
    stats.update({key: tensor.item() for key, tensor in extra_stats.items()})

    # Count the token lengths (first dimension) of every text and response sequence.
    engine.tokens_processed += sum(text.shape[0] for text in batch["text"])
    engine.tokens_processed += sum(resps.shape[0] for resps in batch["resps"])

    return total_loss, stats
@torch.inference_mode()
def run_eval(engines, eval_name, dl, args=None):
    """Generate audio for roughly ``cfg.evaluation.size`` samples and log metrics.

    Batches are drawn from ``dl``. For each batch, every engine in ``engines``
    produces responses. The hypothesis, reference and prompt audio are decoded
    to ``.wav`` files under ``cfg.log_dir / <global step>``. Where a reference
    exists, a mel-STFT loss between hypothesis and reference is recorded.
    Averaged metrics are logged once at the end, keyed per engine as
    ``"<engine name>.<eval_name>"`` together with ``"it"`` (the global step).

    Args:
        engines: mapping-like collection of engines (iterated by name, indexed
            by name) that also exposes ``global_step``.
        eval_name: label for this split (e.g. ``"subtrain"`` or ``"val"``),
            used in output paths and metric keys.
        dl: dataloader providing evaluation batches. It is re-iterated from the
            start if exhausted before enough samples are processed.
        args: optional parsed CLI args. When ``args.eval_random_text_prompts``
            is set and ``eval_name == "subtrain"``, batch text is replaced with
            random prompts and the references are dropped.
    """
    # Nothing to do (this also avoids touching the dataloader at all).
    if cfg.evaluation.size <= 0:
        return

    # Per-engine metric lists: stats[engine_name][metric] -> list of floats.
    # Keeping them per engine stops results from several engines being mixed
    # together and attributed to whichever engine happened to be iterated last.
    stats = defaultdict(lambda: defaultdict(list))

    def process(name, batch, resps_list):
        """Decode each sample to .wav files and, if a reference exists, score it."""
        step_dir = cfg.log_dir / str(engines.global_step)
        samples = zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"])
        for speaker, path, ref, hyp, prom, task in samples:
            # Skip samples for which the engine produced nothing.
            if len(hyp) == 0:
                continue

            filename = f'{speaker}_{path.parts[-1]}'
            if task != "tts":
                filename = f"{filename}_{task}"

            # A prompt may be a sequence of segments; keep only the tensor parts
            # and concatenate them so they decode as a single clip.
            if not isinstance(prom, torch.Tensor) and prom is not None:
                prom = torch.concat([p for p in prom if isinstance(p, torch.Tensor)])

            # to-do, refine the output dir to be sane-er
            ref_path = (step_dir / "ref" / filename).with_suffix(".wav")
            hyp_path = (step_dir / name / eval_name / filename).with_suffix(".wav")
            prom_path = (step_dir / name / "prom" / filename).with_suffix(".wav")
            for out_path in (hyp_path, ref_path, prom_path):
                out_path.parent.mkdir(parents=True, exist_ok=True)

            hyp_audio, _ = qnt.decode_to_file(hyp, hyp_path)
            if ref is not None:
                ref_audio, _ = qnt.decode_to_file(ref, ref_path)
            if prom is not None:
                qnt.decode_to_file(prom, prom_path)

            # naive loss calculation
            # to-do: find a better way to calculate this / a better metric
            if ref is not None:
                # Trim both clips to a common length before comparing them.
                min_length = min(ref_audio.shape[-1], hyp_audio.shape[-1])
                ref_audio = ref_audio[..., 0:min_length]
                hyp_audio = hyp_audio[..., 0:min_length]
                stats[name]['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())

    # Keep one iterator alive across loop passes. Calling ``next(iter(dl))`` each
    # time rebuilds the iterator (restarting any loader workers), and for an
    # unshuffled loader it would keep returning the very first batch.
    data_iter = iter(dl)

    def next_batch():
        """Return the next batch, starting a fresh pass once ``dl`` is exhausted."""
        nonlocal data_iter
        try:
            return next(data_iter)
        except StopIteration:
            data_iter = iter(dl)
            return next(data_iter)  # still empty: StopIteration propagates to the caller

    processed = 0
    while processed < cfg.evaluation.size:
        batch = to_device(next_batch(), cfg.device)

        # limit to eval batch size in the event we somehow have a weird dataloader
        for key in batch.keys():
            batch[key] = batch[key][:cfg.evaluation.batch_size]
        batch_size = len(batch["text"])
        if batch_size == 0:
            # `processed` would never advance, so the loop could not terminate.
            _logger.warning(f"Empty evaluation batch for {eval_name}; stopping evaluation early.")
            break

        # to-do: eval for text tasks
        has_stt = False
        for i, task in enumerate(batch["task"]):
            # easier to just change it to a tts task than drop stt tasks from the batch
            if task == "stt":
                # has_stt = True
                batch["task"][i] = "tts"
                # Use the first 75*3 response frames as the prompt (presumably ~3s at
                # 75 codec frames per second — confirm against the codec config).
                batch["proms"][i] = batch["resps"][i][:75*3, :]

        # random prompts requested
        if args and args.eval_random_text_prompts and eval_name == "subtrain":
            for i, _ in enumerate(batch["text"]):
                batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
                batch["resps"][i] = None

        processed += batch_size

        for name in engines:
            engine = engines[name]

            base_kwargs = dict(
                text_list=batch["text"],
                proms_list=batch["proms"],
                lang_list=batch["lang"],
                task_list=batch["task"],
            )

            if engine.hyper_config.experimental.hf:
                resps_list = engine(**base_kwargs)
            elif "len" in engine.hyper_config.capabilities:
                # First predict output lengths (a few steps suffice), capped at
                # the configured AR max_steps, then fill in with the NAR.
                kwargs = base_kwargs | cfg.evaluation.ar_kwargs
                max_steps = kwargs.pop("max_steps", 500)
                kwargs["max_steps"] = 10
                len_list = engine(**kwargs)  # don't need more than that
                len_list = [min(l, max_steps) for l in len_list]

                kwargs = base_kwargs | cfg.evaluation.nar_kwargs
                resps_list = engine(**kwargs, len_list=len_list)
            else:
                if "ar" in engine.hyper_config.capabilities:
                    kwargs = base_kwargs | cfg.evaluation.ar_kwargs
                    resps_list = engine(**kwargs)
                else:
                    # No AR: seed the NAR with the ground-truth first codebook level.
                    resps_list = [resp[:, 0] for resp in batch["resps"]]

                if "nar" in engine.hyper_config.capabilities:
                    kwargs = base_kwargs | cfg.evaluation.nar_kwargs
                    resps_list = engine(**kwargs, resps_list=resps_list)

            process(name, batch, resps_list)

            # evaluate why it's so slow
            # NOTE: has_stt is never set to True above, so this branch is inactive.
            # It builds its own kwargs instead of reusing whatever `kwargs` the branch
            # above left behind (undefined on the HF path).
            if has_stt:
                max_steps = max(text.shape[0] for text in batch["text"])
                stt_kwargs = base_kwargs | dict(
                    text_list=None,
                    task_list=["stt" for _ in range(batch_size)],
                    proms_list=[["stt"] for _ in range(batch_size)],
                    resps_list=batch["resps"],
                )
                text_list = engine(**stt_kwargs, max_steps=max_steps, sampling_temperature=0.0)
                text_list = [cfg.tokenizer.decode(text) for text in text_list]

                _logger.info(f"Validation Metrics (STT): {text_list}")

    # Average each non-empty metric list, per engine.
    engines_stats = {
        f'{name}.{eval_name}': {k: sum(v) / len(v) for k, v in stats[name].items() if v}
        for name in engines
    }
    engines_stats["it"] = engines.global_step
    #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

    _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
2023-08-04 01:26:36 +00:00
2023-08-02 21:53:35 +00:00
def train():
    """Command-line entry point for training, or for evaluation-only runs.

    Parses the flags below (unrecognised arguments are ignored), sets up
    logging, backs up the YAML config into the log directory, and builds the
    dataloaders. It then either runs one evaluation pass (``--eval``) or hands
    off to ``trainer.train``.

    Flags:
        --eval                      only evaluate freshly loaded engines
        --eval-random-text-prompts  use random text prompts for subtrain eval
    """
    parser = argparse.ArgumentParser("VALL-E TTS")
    # Absent flags default to None (falsy) rather than False.
    for flag in ("--eval", "--eval-random-text-prompts"):
        parser.add_argument(flag, action="store_true", default=None)
    args, _unknown = parser.parse_known_args()

    setup_logging(cfg.log_dir)

    # Keep a copy of the YAML config next to the logs; only the global leader writes it.
    if cfg.yaml_path is not None and is_global_leader():
        shutil.copy(cfg.yaml_path, cfg.log_dir / "config.yaml")

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
    eval_splits = (("subtrain", subtrain_dl), ("val", val_dl))

    def eval_fn(engines):
        """Evaluate on the subtrain and val splits, then switch back to training mode."""
        do_gc()
        engines.eval()
        # Evaluation is best-effort: any failure (including one that skips the
        # remaining splits) is logged, and training continues.
        try:
            for split_name, split_dl in eval_splits:
                run_eval(engines, split_name, split_dl, args)
        except Exception as e:
            _logger.warning(f"Error occurred while performing eval: {str(e)}")
            _logger.warning(traceback.format_exc())

        engines.train()
        qnt.unload_model()
        do_gc()

    # Unload the audio codec (EnCodec) if it is already loaded.
    qnt.unload_model()

    if args.eval:
        return eval_fn(engines=trainer.load_engines())

    # Web UI launch is currently disabled:
    # if cfg.trainer.load_webui:
    #     from .webui import start
    #     start(lock=False)

    # Pre-training config validation.
    if cfg.model.experimental.layerskip and cfg.trainer.weight_dtype == "float16":
        _logger.warning(f"Training with LayerSkip enabled with float16 may result in frying the model if the loss scale gets too small (<=8K) or with too large of a de facto batch size (>512 samples).")

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )


if __name__ == "__main__":
    # to-do: for DDP, spawn multiprocess instead of requiring `torchrun --nnodes=1 --nproc-per-node=4 -m vall_e.train yaml="./data/config.yaml"`
    train()