From ceecac6ffe30d53c41114a32b974cd9c87e218a3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 26 Feb 2025 23:13:32 -0600
Subject: [PATCH] I think I made resp_parallel_training=True faster with loss factoring?

---
 vall_e/engines/base.py   |  5 ++++-
 vall_e/models/base_v2.py | 42 +++++++++++-------------------------------
 2 files changed, 15 insertions(+), 32 deletions(-)

diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 8ffc733..5b312ee 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -629,7 +629,10 @@ class Engines(dict[str, Engine]):
 			if cfg.lora is not None:
 				key_name = cfg.lora.full_name
 
-			stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
+			if len(self) == 1:
+				stats.update(flatten_dict(model_stats))
+			else:
+				stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
 
 		self._update()
 
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index e8cca54..a8ee428 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -748,7 +748,7 @@ class Base_V2(nn.Module):
 			# filter tokens that exceed the vocab size
 			sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
 			# drop if all tokens are ignored
-			if all(sequence == self.ignore_index):
+			if torch.all(sequence == self.ignore_index):
 				return None, None
 
 			# shift if causal
@@ -757,8 +757,14 @@ class Base_V2(nn.Module):
 				logit = logit[..., :-l, :] # shift the target so that token n...
 				sequence = sequence[..., l:] # ...predicts token n + 1
 
+			# flatten batch
+			if sequence.dim() > 1:
+				logit = logit.reshape(-1, logit.shape[-1])
+				sequence = sequence.reshape(-1)
+
 			nll = None
 			metrics = None
+
 			if compute_hard_loss:
 				nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
 
@@ -868,41 +874,15 @@ class Base_V2(nn.Module):
 						if classifier_level.endswith(f':{i}:{i}'):
 							level = i
 							break
+					"""
 					if name == "resp":
 						name = f'{name}[{level}]'
+					"""
 					sequence = token if token.dim() <= 1 else token[:, level]
 					nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
 				else:
-					nlls = []
-					accs = []
-
-					for level, logit in enumerate( logits[batch_index] ):
-						sequence = token if token.dim() <= 1 else token[:, level]
-						nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
-
-						if name == "resp":
-							if nll is not None:
-								if f'{name}[{level}].nll' not in loss:
-									loss[f'{name}[{level}].nll'] = []
-								loss[f"{name}[{level}].nll"].append( nll * loss_factor )
-
-							if metrics is not None:
-								if f'{name}[{level}].acc' not in stats:
-									stats[f'{name}[{level}].acc'] = []
-								stats[f"{name}[{level}].acc"].append( metrics )
-
-							nll = None
-							metrics = None
-						else:
-							if nll:
-								nlls.append( nll )
-							if metrics:
-								accs.append( metrics )
-					if nlls:
-						nll = sum(nlls) / len(nlls)
-					if accs:
-						accs = sum(accs) / len(accs)
-
+					sequence = token.t()
+					nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
 				if nll is not None:
 					if f'{name}.nll' not in loss:
 						loss[f'{name}.nll'] = []