# vall-e/vall_e/train.py

# todo: clean this mess up
from .config import cfg
from .data import create_train_val_dataloader
from .emb import qnt
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .utils.trainer import load_engines
import auraloss
import json
import logging
import random
import torch
import torch.nn.functional as F
import traceback
from collections import defaultdict
from tqdm import tqdm
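
# mel-scaled STFT loss (auraloss) used as a pseudo-loss between decoded reference and generated audio during eval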
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
def center_crop(x, len):
    # crop the last dimension to `len` samples, keeping the center
    start = (x.shape[-1] - len) // 2
    stop = start + len
    return x[..., start:stop]

def left_crop(x, len):
    # crop the last dimension to the first `len` samples
    return x[..., 0:len]

_logger = logging.getLogger(__name__)
def train_feeder(engine, batch):
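    # forward pass; the model records its losses/stats as attributes, which are gathered below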
    engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )

    losses = engine.gather_attribute("loss")
    stat = engine.gather_attribute("stats")

    loss = torch.stack([*losses.values()]).sum()

    stats = {}
    stats |= {k: v.item() for k, v in losses.items()}
    stats |= {k: v.item() for k, v in stat.items()}

    return loss, stats

@torch.inference_mode()
def run_eval(engines, eval_name, dl):
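    # runs inference with the AR (and/or NAR) engines over an eval dataloader, writes decoded audio to the log dir, and records a mel-STFT pseudo-loss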
    engines_stats = {
        'eval': eval_name
    }

    AR = None
    NAR = None

    names = []
    for name, engine in engines.items():
        names.append(name)
        if name[:2] == "ar":
            AR = engine
        elif name[:3] == "nar":
            NAR = engine

    stats = defaultdict(list)
    stats['loss'] = []
    def process( name, batch, resps_list ):
        for speaker, path, ref, hyp, prom in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]):
            if len(hyp) == 0:
                continue

            filename = f'{speaker}_{path.parts[-1]}'

            # to-do, refine the output dir to be sane-er
            ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
            hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
            prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")

            hyp_path.parent.mkdir(parents=True, exist_ok=True)
            ref_path.parent.mkdir(parents=True, exist_ok=True)
            prom_path.parent.mkdir(parents=True, exist_ok=True)

            ref_audio, sr = qnt.decode_to_file(ref, ref_path)
            hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
            prom_audio, sr = qnt.decode_to_file(prom, prom_path)

            # pseudo loss calculation since we don't get the logits during eval
            min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
            ref_audio = center_crop(ref_audio, min_length) #ref_audio[..., 0:min_length]
            hyp_audio = center_crop(hyp_audio, min_length) #hyp_audio[..., 0:min_length]

            try:
                stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
            except Exception as e:
                stats['loss'].append(0)
                print(str(e))
    processed = 0
    for batch in tqdm(dl):
        batch: dict = to_device(batch, cfg.device)

        # if we're training both models, provide output for both
        if AR is not None and NAR is not None:
            name = "+".join(names)

            # sample the AR for the first quantizer level, then have the NAR infer the remaining levels from it
            resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
            resps_list = [ r.unsqueeze(-1) for r in resps_list ]
            resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)

            process( name, batch, resps_list )
        else:
            for name in engines:
                model = engines[name]

                if name.startswith("ar"):
                    resps_list = model(
                        text_list=batch["text"],
                        proms_list=batch["proms"],
                        max_steps=cfg.evaluation.steps,
                        sampling_temperature=cfg.evaluation.ar_temperature,
                    )
                    resps_list = [r.unsqueeze(-1) for r in resps_list]
                elif name.startswith("nar"):
                    # when evaluating the NAR alone, condition it on the ground-truth first quantizer level
                    resps_list = model(
                        text_list=batch["text"],
                        proms_list=batch["proms"],
                        resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
                        sampling_temperature=cfg.evaluation.nar_temperature,
                    )
                else:
                    raise NotImplementedError(name)

                process( name, batch, resps_list )

        processed += len(batch["text"])
        if processed >= cfg.evaluation.size:
            break

    stats = {k: sum(v) / len(v) for k, v in stats.items()}
    engines_stats.update(flatten_dict({ name: stats }))

    iteration = engines.global_step
    engines_stats['it'] = iteration
    #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

    _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")

def main():
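    # set up logging and the dataloaders, then hand off to the shared trainer loop with run_eval wired in as the eval callback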
    setup_logging(cfg.log_dir)

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    def eval_fn(engines):
        try:
            run_eval(engines, "subtrain", subtrain_dl)
            run_eval(engines, "val", val_dl)
        except Exception as e:
            print("Error occurred while performing eval:", str(e))
            print(traceback.format_exc())

        qnt.unload_model()
        do_gc()

    qnt.unload_model()

    trainer.train(
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )

if __name__ == "__main__":
    main()