267 lines
7.9 KiB
Python
267 lines
7.9 KiB
Python
|
# todo: clean this mess up
|
||
|
# todo: yank deepspeed dependent code out into its own thing
|
||
|
|
||
|
from .config import cfg
|
||
|
from .data import create_train_val_dataloader
|
||
|
from .emb import qnt
|
||
|
|
||
|
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
|
||
|
from .utils import wrapper as ml
|
||
|
|
||
|
from .models import get_models
|
||
|
|
||
|
import auraloss
|
||
|
import deepspeed
|
||
|
import json
|
||
|
import logging
|
||
|
import random
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
import traceback
|
||
|
|
||
|
from collections import defaultdict
|
||
|
|
||
|
from deepspeed import comm as dist
|
||
|
from deepspeed import DeepSpeedConfig
|
||
|
from deepspeed.accelerator import get_accelerator
|
||
|
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
# Shared mel-STFT loss instance, called in run_eval() to score decoded
# hypothesis audio against the decoded reference audio.
# It is constructed at import time with device="cuda" hard-coded.
# NOTE(review): 24_000 is presumably the audio sample rate in Hz -- confirm
# against the auraloss MelSTFTLoss signature and the codec in use.
mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda")
|
||
|
|
||
|
def center_crop(x, len):
    """Return a window of ``len`` elements centred on the last axis of ``x``.

    Args:
        x: Array-like exposing ``.shape`` and supporting ``x[..., a:b]``
            slicing (e.g. a torch tensor or numpy array).
        len: Requested window length along the last axis. The name shadows
            the builtin but is kept so existing keyword callers still work.

    Returns:
        ``x`` sliced along its last axis. When ``len`` is greater than or
        equal to the size of the last axis, the whole input is returned,
        which is the same outcome ``left_crop`` gives for an oversized
        request.
    """
    # Clamp at zero: without it an oversized ``len`` makes ``start`` negative,
    # and a negative slice start counts from the end, yielding only a tail
    # fragment instead of the full input.
    start = max((x.shape[-1] - len) // 2, 0)
    return x[..., start:start + len]
|
||
|
|
||
|
def left_crop(x, len):
    """Keep the first ``len`` elements along the last axis of ``x``.

    Args:
        x: Object supporting ``x[..., a:b]`` slicing (e.g. a torch tensor or
            numpy array).
        len: Number of leading elements to keep on the last axis.

    Returns:
        The sliced view/copy produced by ``x[..., :len]``.
    """
    return x[..., :len]
|
||
|
|
||
|
# Module-level logger; run_eval() uses it to emit the validation metrics line.
_logger = logging.getLogger(__name__)
|
||
|
# Ad-hoc flag stored on the imported deepspeed module itself. Both
# load_engines() and main() check it so that deepspeed.init_distributed()
# is only called by whichever of them runs first.
deepspeed._initialized_dist = False
|
||
|
|
||
|
def load_engines():
    """Wrap every configured model in a ``trainer.Engine`` and load them.

    Steps, in order:

    1. Call ``deepspeed.init_distributed()`` if the module-level
       ``deepspeed._initialized_dist`` flag says nobody has done so yet.
    2. Fetch the models via ``get_models(cfg.models.get())``.
    3. For each model: optionally build a torch-side AdamW optimizer,
       optionally restore weights from ``<cfg.ckpt_dir>/<name>/fp32.pth``,
       and construct a ``trainer.Engine`` with the DeepSpeed config obtained
       from ``cfg.get_ds_cfg(model=model)``.

    Returns:
        The result of ``trainer.load_engines(engines, cfg)``, where
        ``engines`` maps each model name to its ``trainer.Engine``.
    """
    if not deepspeed._initialized_dist:
        deepspeed._initialized_dist = True
        deepspeed.init_distributed()

    models = get_models(cfg.models.get())
    engines = dict()

    for name in models:
        model = models[name]

        # Only "adamw-torch" is handled here; for any other setting the engine
        # receives optimizer=None.
        # NOTE(review): presumably the optimizer then comes from the DeepSpeed
        # config -- confirm in trainer.Engine.
        optimizer = None
        if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
            optimizer = ml.AdamW(
                model.parameters(),
                lr=cfg.hyperparameters.learning_rate,
                betas=(0.9, 0.96),
                eps=1e-07,
                weight_decay=0.01,
            )

        if cfg.trainer.load_state_dict:
            load_path = cfg.ckpt_dir / name / "fp32.pth"
            # Deserialise onto the CPU first; load_state_dict() then copies
            # each tensor into the model's own parameters. Without
            # map_location the tensors would be restored onto whatever device
            # they were saved from.
            model.load_state_dict(torch.load(load_path, map_location="cpu"))

        ds_cfg = cfg.get_ds_cfg(model=model)
        engines[name] = trainer.Engine(
            model=model,
            config=ds_cfg,
            config_class=DeepSpeedConfig(ds_cfg),
            optimizer=optimizer,
            # No scheduler is ever built in this function.
            lr_scheduler=None,
        )

    return trainer.load_engines(engines, cfg)
|
||
|
|
||
|
def _export_and_score(step, name, eval_name, batch, resps_list):
    """Decode one evaluated batch to .wav files and score each utterance.

    For every utterance with a non-empty hypothesis, three files are written
    under ``cfg.log_dir / str(step)``:

    * ``ref/<file>.wav``                 -- decoded ``batch["resps"]`` entry
    * ``<name>/<eval_name>/<file>.wav``  -- decoded hypothesis
    * ``<name>/prom/<file>.wav``         -- decoded ``batch["proms"]`` entry

    where ``<file>`` is ``"<speaker>_<last component of path>"``.

    Args:
        step: Value used for the per-step output directory name.
        name: Engine (or joined engines) name used in the output path.
        eval_name: Evaluation split label used in the output path.
        batch: Mapping providing ``"spkr_name"``, ``"path"``, ``"resps"`` and
            ``"proms"`` sequences; ``zip`` stops at the shortest of them.
        resps_list: Model outputs, one per utterance.

    Returns:
        A list of floats: the mel-STFT loss between hypothesis and reference
        audio for each utterance that was not skipped.
    """
    step_dir = cfg.log_dir / str(step)
    losses = []

    for speaker, path, ref, hyp, prom in zip(
        batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"]
    ):
        # Explicit length check rather than truthiness.
        # NOTE(review): hyp is presumably a tensor, for which bool() is
        # ambiguous -- confirm.
        if len(hyp) == 0:
            continue

        filename = f'{speaker}_{path.parts[-1]}'

        # to-do, refine the output dir to be sane-er
        ref_path = (step_dir / "ref" / filename).with_suffix(".wav")
        hyp_path = (step_dir / name / eval_name / filename).with_suffix(".wav")
        prom_path = (step_dir / name / "prom" / filename).with_suffix(".wav")

        for target in (hyp_path, ref_path, prom_path):
            target.parent.mkdir(parents=True, exist_ok=True)

        # decode_to_file returns a pair; only its first element (used as the
        # waveform below) is needed. The prompt is decoded purely for the
        # file it writes.
        ref_audio, _ = qnt.decode_to_file(ref, ref_path)
        hyp_audio, _ = qnt.decode_to_file(hyp, hyp_path)
        qnt.decode_to_file(prom, prom_path)

        # pseudo loss calculation since we don't get the logits during eval;
        # both signals are trimmed to the shorter one before comparing.
        min_length = min(ref_audio.shape[-1], hyp_audio.shape[-1])
        ref_audio = ref_audio[..., 0:min_length]
        hyp_audio = hyp_audio[..., 0:min_length]

        losses.append(mel_stft_loss(hyp_audio, ref_audio).item())

    return losses


def main():
    """Training entry point.

    Initialises the DeepSpeed distributed backend if that has not happened
    yet, configures logging, builds the train / subtrain / validation
    dataloaders, and hands ``load_engines`` plus the locally defined
    ``train_feeder`` and ``eval_fn`` callbacks to ``trainer.train``.
    """
    if not deepspeed._initialized_dist:
        deepspeed._initialized_dist = True
        deepspeed.init_distributed()

    setup_logging(cfg.log_dir)

    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

    def train_feeder(engines, batch, name):
        """Run one forward pass of engine ``name`` and collect its losses.

        Engines whose name starts with "ar" are fed only index 0 of the last
        axis of each response; names starting with "nar" receive the
        responses unchanged. Any other name raises ``NotImplementedError``.

        Returns:
            ``(loss, stats)``: the sum of the losses gathered from the model,
            and a dict of those losses as Python numbers merged with
            ``engines.gather_attribute("scalar")``.
        """
        model = engines[name]

        if name.startswith("ar"):
            model(
                text_list=batch["text"],
                proms_list=batch["proms"],
                resp_list=[r[..., 0] for r in batch["resps"]],
            )
        elif name.startswith("nar"):
            model(
                text_list=batch["text"],
                proms_list=batch["proms"],
                resps_list=batch["resps"],
            )
        else:
            raise NotImplementedError(name)

        # The forward result is discarded; losses are read back from
        # attributes via gather_attribute.
        losses = model.gather_attribute("loss")
        loss = torch.stack([*losses.values()]).sum()

        stats = {k: v.item() for k, v in losses.items()}
        stats |= engines.gather_attribute("scalar")

        return loss, stats

    @torch.inference_mode()
    def run_eval(engines, eval_name, dl):
        """Evaluate on ``dl``, write decoded audio, and log averaged metrics.

        If both an "ar*" and a "nar*" engine are present, the AR output is
        fed into the NAR model and the result is reported under the engine
        names joined with "+". Otherwise each engine is evaluated on its own
        and all metrics are pooled and reported under the last engine's name.
        """
        engines_stats = {'eval': eval_name}

        AR = None
        NAR = None
        names = []
        for name in engines:
            model = engines[name]
            names.append(name)
            if name.startswith("ar"):
                AR = model
            elif name.startswith("nar"):
                NAR = model

        chained = AR is not None and NAR is not None

        stats = defaultdict(list)
        # Pre-seed so 'loss' stays the first reported key.
        stats['loss'] = []

        for batch in tqdm(dl):
            batch: dict = to_device(batch, cfg.device)

            # if we're training both models, provide output for both
            if chained:
                name = "+".join(names)

                resp_list = AR(
                    text_list=batch["text"],
                    proms_list=batch["proms"],
                    max_steps=cfg.evaluation.steps,
                    sampling_temperature=cfg.evaluation.temperature,
                )
                resps_list = [r.unsqueeze(-1) for r in resp_list]
                resps_list = NAR(
                    text_list=batch["text"],
                    proms_list=batch["proms"],
                    resps_list=resps_list,
                    sampling_temperature=cfg.evaluation.temperature,
                )

                stats['loss'].extend(
                    _export_and_score(engines.global_step, name, eval_name, batch, resps_list)
                )
            else:
                for name in engines:
                    model = engines[name]

                    if name.startswith("ar"):
                        resp_list = model(
                            text_list=batch["text"],
                            proms_list=batch["proms"],
                            max_steps=cfg.evaluation.steps,
                            sampling_temperature=cfg.evaluation.temperature,
                        )
                        resps_list = [r.unsqueeze(-1) for r in resp_list]
                    elif name.startswith("nar"):
                        resps_list = model(
                            text_list=batch["text"],
                            proms_list=batch["proms"],
                            resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]],
                            sampling_temperature=cfg.evaluation.temperature,
                        )
                    else:
                        raise NotImplementedError(name)

                    losses = model.gather_attribute("loss")

                    batch_stats = {k: v.item() for k, v in losses.items()}
                    batch_stats |= engines.gather_attribute("scalar")

                    for k, v in batch_stats.items():
                        stats[k].append(v)

                    stats['loss'].extend(
                        _export_and_score(engines.global_step, name, eval_name, batch, resps_list)
                    )

        # Average each metric, skipping metrics with no samples. Dividing by
        # len([]) used to raise ZeroDivisionError here, which eval_fn caught,
        # silently dropping the metrics and any remaining evaluation pass.
        averaged = {k: sum(v) / len(v) for k, v in stats.items() if v}
        # `name` is whatever the loops above bound last: the joined name when
        # chaining, otherwise the last engine name.
        engines_stats.update(flatten_dict({name: averaged}))

        iteration = engines.global_step
        engines_stats['it'] = iteration
        engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)

        _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")

    def eval_fn(engines):
        """Evaluate on the subtrain and validation sets, then clean up.

        Evaluation is best-effort: any exception is printed with its
        traceback and otherwise ignored. Afterwards ``qnt.unload_model()``
        and ``do_gc()`` always run.
        """
        try:
            run_eval(engines, "subtrain", subtrain_dl)
            run_eval(engines, "val", val_dl)
        except Exception as e:
            print("Error occurred while performing eval:", str(e))
            print(traceback.format_exc())

        qnt.unload_model()
        do_gc()

    # NOTE(review): presumably releases a codec model loaded while the
    # dataloaders were built -- confirm in qnt.
    qnt.unload_model()

    trainer.train(
        engines_loader=load_engines,
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )
|
||
|
|
||
|
|
||
|
# Start training only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
|