2023-08-02 21:53:35 +00:00
|
|
|
# TODO: clean up this module.
|
|
|
|
|
|
|
|
from .config import cfg
|
|
|
|
from .data import create_train_val_dataloader
|
|
|
|
from .emb import qnt
|
|
|
|
|
|
|
|
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
|
|
|
|
|
|
|
|
import auraloss
|
|
|
|
import json
|
|
|
|
import logging
|
|
|
|
import random
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
2023-08-14 03:07:45 +00:00
|
|
|
# Module-level logger; used for validation metrics reporting in this script.
_logger = logging.getLogger(__name__)

# Mel-spectrogram STFT loss used as a pseudo loss during evaluation, where logits are
# not available. 24_000 is passed as the first positional argument (presumably the
# sample rate in Hz; confirm it matches the codec's decode rate). Built on CPU.
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
|
2023-08-04 01:26:36 +00:00
|
|
|
|
|
|
|
def train_feeder(engine, batch):
	"""Run one training forward pass of ``engine`` on ``batch``.

	The engine is invoked under CUDA autocast, with dtype and enable flag taken from
	``cfg.trainer``. Prompts are sliced along their second dimension to
	``engine._cfg.prom_levels``, reducing the input prompt to the target prom level.

	Side effect: ``engine.tokens_processed`` grows by the summed first-dimension
	lengths of every text tensor and every response tensor in the batch.

	Returns:
		(loss, stats): ``loss`` is the sum of all tensors gathered under the engine's
		"loss" attribute. ``stats`` maps each of those loss names, plus each name
		gathered under "stats", to its ``.item()`` value.
	"""
	prom_levels = engine._cfg.prom_levels
	truncated_proms = [prom[:, :prom_levels] for prom in batch["proms"]]

	with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
		engine(
			text_list=batch["text"],
			proms_list=truncated_proms,
			resps_list=batch["resps"],
			lang_list=batch["lang"],
		)

	loss_terms = engine.gather_attribute("loss")
	extra_stats = engine.gather_attribute("stats")

	total_loss = torch.stack(list(loss_terms.values())).sum()

	stats = {key: value.item() for key, value in loss_terms.items()}
	stats.update({key: value.item() for key, value in extra_stats.items()})

	# Count the tokens consumed by this step: text first, then responses.
	for field in ("text", "resps"):
		engine.tokens_processed += sum(seq.shape[0] for seq in batch[field])

	return total_loss, stats
|
|
|
|
|
|
|
|
@torch.inference_mode()
def run_eval(engines, eval_name, dl):
	"""Generate audio with the evaluable engines and log a mel-STFT pseudo loss.

	Engine selection is by name prefix: "ar+nar", "ar" or "nar". Engines with any
	other name are ignored. If both an "ar" and a "nar" engine are present, they are
	chained (AR output fed to NAR) and evaluated as a single pipeline. Otherwise every
	selected engine is evaluated on its own. "nar" engines are given index 0 of the
	last axis of the reference responses as their input (presumably the first
	quantizer level; confirm).

	For each sample, the reference, hypothesis and prompt are decoded to .wav files
	under ``cfg.log_dir/<global_step>/``. The mel-STFT loss between the hypothesis and
	the reference (both truncated to their common length) is recorded. Losses are
	pooled across all evaluated engines and averaged, then logged under the key
	"<'+'-joined engine names>.<eval_name>".

	Batches come from a single pass over ``dl``. Evaluation stops once at least
	``cfg.evaluation.size`` items have been processed, or when ``dl`` is exhausted.

	Args:
		engines: engine collection; must support ``.items()``, ``[name]`` and
			``.global_step``.
		eval_name: label used in output paths and in the logged metric key
			(e.g. "subtrain", "val").
		dl: iterable of batch dicts with "text", "proms", "resps", "spkr_name",
			"path" and "task" entries.
	"""
	AR = None
	NAR = None
	names = []

	for name, engine in engines.items():
		if name.startswith("ar+nar"):
			pass  # combined AR+NAR model; evaluated on its own in the per-engine path
		elif name.startswith("ar"):
			AR = engine
		elif name.startswith("nar"):
			NAR = engine
		else:
			continue  # unknown engine type: not evaluated
		names.append(name)

	if not names:
		_logger.warning(f"No evaluable engines found; skipping {eval_name} evaluation.")
		return

	# Label for the pooled metrics. Computed explicitly instead of reusing a leaked
	# loop variable.
	stats_name = "+".join(names)

	stats = defaultdict(list)
	stats['loss'] = []

	def process(name, batch, resps_list):
		"""Write ref/hyp/prompt audio for each sample and record its mel-STFT loss."""
		samples = zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"])
		for speaker, path, ref, hyp, prom, task in samples:
			if len(hyp) == 0:
				continue

			filename = f'{speaker}_{path.parts[-1]}'
			if task != "tts":
				filename = f"{filename}_{task}"

			# to-do, refine the output dir to be sane-er
			step_dir = cfg.log_dir / str(engines.global_step)
			ref_path = (step_dir / "ref" / filename).with_suffix(".wav")
			hyp_path = (step_dir / name / eval_name / filename).with_suffix(".wav")
			prom_path = (step_dir / name / "prom" / filename).with_suffix(".wav")

			for out_path in (hyp_path, ref_path, prom_path):
				out_path.parent.mkdir(parents=True, exist_ok=True)

			ref_audio, _ = qnt.decode_to_file(ref, ref_path)
			hyp_audio, _ = qnt.decode_to_file(hyp, hyp_path)
			qnt.decode_to_file(prom, prom_path)  # written for inspection only

			# pseudo loss calculation since we don't get the logits during eval
			min_length = min(ref_audio.shape[-1], hyp_audio.shape[-1])
			ref_audio = ref_audio[..., :min_length]
			hyp_audio = hyp_audio[..., :min_length]
			stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())

	def generate(name, batch):
		"""Run the engine registered as ``name`` on ``batch`` and return its responses."""
		model = engines[name]
		if name.startswith("ar+nar"):
			# Use this engine for both passes, so that several "ar+nar" engines are
			# each evaluated (previously only the last one ever was).
			resps_list = model(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
			resps_list = [r.unsqueeze(-1) for r in resps_list]
			return model(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
		if name.startswith("ar"):
			resps_list = model(
				text_list=batch["text"],
				proms_list=batch["proms"],
				max_steps=cfg.evaluation.steps,
				sampling_temperature=cfg.evaluation.ar_temperature,
			)
			return [r.unsqueeze(-1) for r in resps_list]
		if name.startswith("nar"):
			return model(
				text_list=batch["text"],
				proms_list=batch["proms"],
				resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
				sampling_temperature=cfg.evaluation.nar_temperature,
			)
		raise NotImplementedError(name)

	# Create the iterator once. Calling next(iter(dl)) on every pass restarted the
	# loader each time, which re-evaluated the same first batch on unshuffled loaders.
	dl_iter = iter(dl)
	processed = 0
	while processed < cfg.evaluation.size:
		try:
			raw_batch = next(dl_iter)
		except StopIteration:
			break  # loader exhausted before reaching cfg.evaluation.size
		batch: dict = to_device(raw_batch, cfg.device)
		processed += len(batch["text"])

		# if we're training both models, provide output for both
		if AR is not None and NAR is not None:
			resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
			resps_list = [r.unsqueeze(-1) for r in resps_list]
			resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
			process(stats_name, batch, resps_list)
			continue

		# Only the engines selected above; unknown names are intentionally skipped.
		for name in names:
			process(name, batch, generate(name, batch))

	# Skip empty series: every hypothesis may have been empty.
	stats = {k: sum(v) / len(v) for k, v in stats.items() if v}

	engines_stats = {
		f'{stats_name}.{eval_name}': stats,
		"it": engines.global_step,
	}

	_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
|
|
|
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2023-10-21 14:55:38 +00:00
|
|
|
def train():
	"""Set up logging and dataloaders, then hand control to ``trainer.train``.

	The evaluation callback runs ``run_eval`` on the "subtrain" and "val" loaders.
	Exceptions raised there are logged, with traceback, rather than propagated, so a
	failed eval does not abort training. After every eval, ``qnt.unload_model()`` and
	``do_gc()`` are called (presumably to release the codec model's memory before
	training resumes). ``qnt.unload_model()`` is also called once before training
	starts.
	"""
	setup_logging(cfg.log_dir)

	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

	def eval_fn(engines):
		try:
			run_eval(engines, "subtrain", subtrain_dl)
			run_eval(engines, "val", val_dl)
		except Exception:
			# Best effort: an eval failure must not kill the run. It is logged through
			# the module logger (not print) so it reaches the logging configured by
			# setup_logging() above.
			_logger.exception("Error occurred while performing eval")
		finally:
			qnt.unload_model()
			do_gc()

	qnt.unload_model()

	# Web UI launch is currently disabled. Previously:
	#   if cfg.trainer.load_webui:
	#       from .webui import start
	#       start(lock=False)

	trainer.train(
		train_dl=train_dl,
		train_feeder=train_feeder,
		eval_fn=eval_fn,
	)
|
|
|
|
|
|
|
|
# Start training only when this module is run as the main program, not when it is
# imported. The relative imports at the top mean it must be launched with `python -m`.
if __name__ == "__main__":
	train()
|