diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index ad1eab7..7f42f42 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -1037,7 +1037,8 @@ def example_usage():
 		texts, proms, resps, tasks = sample_data()
 
 		stats = {"step": i}
-		stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
+		with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
+			stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
 		stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 		tqdm.write(f"{stats}")
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index cbba2e9..9efe150 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -944,13 +944,8 @@ class Base_V2(nn.Module):
 				sequence = token.t()
 				nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
 
-				for level in enumerate(self.n_resp_levels):
-					loss_key = f'{name}[{level}].nll'
-					if loss_key not in loss:
-						loss[loss_key] = []
-					loss[loss_key].append( nll[level] * loss_factor )
-
-				nll = None
+				if nll is not None:
+					nll = nll.sum()
 
 				loss_key = f'{name}.nll'
 				acc_key = f'{name}.acc'
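
The first hunk wraps the training step in torch.autograd.set_detect_anomaly(), gated by a cfg.trainer.detect_grad_anomaly flag. A minimal sketch of what this enables, assuming the flag is a plain boolean; the sqrt-at-zero setup is a hypothetical NaN trigger for illustration, not code from the repo:

import torch

detect_grad_anomaly = True  # stand-in for cfg.trainer.detect_grad_anomaly

x = torch.zeros(1, requires_grad=True)
with torch.autograd.set_detect_anomaly(detect_grad_anomaly):
    y = torch.sqrt(x)        # forward is fine: sqrt(0) == 0
    loss = (y * 0.0).sum()   # sqrt's backward at x == 0 yields 0/0 == nan
    try:
        loss.backward()
    except RuntimeError as e:
        # anomaly mode raises and names the offending backward function,
        # e.g. "Function 'SqrtBackward0' returned nan values in its 0th output."
        print(f"anomaly detected: {e}")

When the flag is False, set_detect_anomaly(False) leaves detection off inside the context, so the wrapper costs effectively nothing in normal training runs.

The second hunk drops the per-level loss fan-out (the removed loop iterated enumerate(self.n_resp_levels), an int, which would raise a TypeError) and instead reduces the per-level NLL to one scalar stored under the single f'{name}.nll' key. A sketch of the new reduction, assuming _calc_loss() returns a vector with one NLL per codebook level; the shapes and the 'resps' key name are illustrative:

import torch

n_resp_levels = 8                # assumed number of codebook levels
nll = torch.rand(n_resp_levels)  # stand-in for the NLL vector from _calc_loss()
loss = {}

if nll is not None:
    nll = nll.sum()              # one scalar across all levels
loss['resps.nll'] = nll          # single key, matching f'{name}.nll'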