diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 8ffc733..5b312ee 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -629,7 +629,10 @@ class Engines(dict[str, Engine]):
 
 			if cfg.lora is not None:
 				key_name = cfg.lora.full_name
-			stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
+			if len(self) == 1:
+				stats.update(flatten_dict(model_stats))
+			else:
+				stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
 
 		self._update()
 
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index e8cca54..e5d19f2 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -868,41 +868,18 @@ class Base_V2(nn.Module):
 					if classifier_level.endswith(f':{i}:{i}'):
 						level = i
 						break
+				"""
 				if name == "resp":
 					name = f'{name}[{level}]'
+				"""
 				sequence = token if token.dim() <= 1 else token[:, level]
 				nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
 			else:
-				nlls = []
-				accs = []
+				vocab_size = logits[batch_index].shape[-1]
 
-				for level, logit in enumerate( logits[batch_index] ):
-					sequence = token if token.dim() <= 1 else token[:, level]
-					nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
-
-					if name == "resp":
-						if nll is not None:
-							if f'{name}[{level}].nll' not in loss:
-								loss[f'{name}[{level}].nll'] = []
-							loss[f"{name}[{level}].nll"].append( nll * loss_factor )
-
-						if metrics is not None:
-							if f'{name}[{level}].acc' not in stats:
-								stats[f'{name}[{level}].acc'] = []
-							stats[f"{name}[{level}].acc"].append( metrics )
-
-						nll = None
-						metrics = None
-					else:
-						if nll:
-							nlls.append( nll )
-						if metrics:
-							accs.append( metrics )
-				if nlls:
-					nll = sum(nlls) / len(nlls)
-				if accs:
-					accs = sum(accs) / len(accs)
-
+				logit = logits[batch_index][:, start:end].reshape(-1, vocab_size)
+				sequence = token.reshape(-1).long()
+				nll, metrics = _calc_loss( logit, sequence, causal )
 			if nll is not None:
 				if f'{name}.nll' not in loss:
 					loss[f'{name}.nll'] = []