From b05a905b95a8ccff2f0af386681811dd6c68d01e Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 5 Jun 2024 21:02:05 -0500
Subject: [PATCH] ugh

---
 vall_e/engines/__init__.py |  4 ++++
 vall_e/models/base.py      | 13 ++++++++-----
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 6890c62..4fb7f45 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -142,6 +142,10 @@ def load_engines(training=True):
 			for k in erase:
 				del state[k]
 
+			# resize text embedding
+			if cfg.model.text_tokens != state["text_emb.weight"].shape[0]:
+				state["text_emb.weight"] = state["text_emb.weight"][:cfg.model.text_tokens]
+
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 		hyper_config = model.hyper_config
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c71e1cc..0bc1445 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -703,8 +703,9 @@ class Base(nn.Module):
 				target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
 
+			batch_size = len(target_list)
 			# modify only for the AR so it can properly behave like a transformer
-			for i in range(len(target_list)):
+			for i in range(batch_size):
 				if quant_levels is not None and quant_levels[i] > 0:
 					continue
 
@@ -725,10 +726,10 @@ class Base(nn.Module):
 				)
 			else:
 				self.loss = dict(
-					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / len(batch)
+					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / batch_size
 				)
 				self.stats = dict(
-					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / len(batch)
+					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 				)
 
 			return
@@ -745,6 +746,8 @@ class Base(nn.Module):
 		self.stats = dict(acc = dict())
 
 		info = {}
+		batch_size = len( inputs )
+
 		for i, batch in enumerate( inputs ):
 			quant_level = quant_levels[i] if quant_levels is not None else None
 
@@ -799,8 +802,8 @@ class Base(nn.Module):
 			# probably consumes less memory due to not having to allocate memory
 			# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
 			else:
-				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / len(batch)
-				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / len(batch)
+				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
+				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
 
 			# accuracy sometimes breaks for mamba
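
For context on the first hunk (a checkpoint whose text embedding has more rows than the current config expects), here is a minimal, self-contained sketch of the same truncate-before-load trick. The `TinyModel` class and the 256 -> 200 token counts are illustrative assumptions, not part of the patch; only the shape check and the row slice mirror what the patch does in `load_engines`.

```python
# Minimal sketch (not part of the patch): if a checkpoint was saved with a
# larger text vocabulary than the current config asks for, truncate its
# embedding rows before load_state_dict() so the shapes line up.
# TinyModel and the 256 -> 200 token counts are illustrative assumptions only.
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, text_tokens: int, dim: int = 16):
        super().__init__()
        self.text_emb = nn.Embedding(text_tokens, dim)

# pretend this state dict came from a run configured with 256 text tokens
state = TinyModel(text_tokens=256).state_dict()

# the current config only wants 200 text tokens
model = TinyModel(text_tokens=200)

# same check-and-slice as the patch: keep only the first `text_tokens` rows
# (this only handles shrinking the vocabulary; growing it would need new rows)
if model.text_emb.num_embeddings != state["text_emb.weight"].shape[0]:
    state["text_emb.weight"] = state["text_emb.weight"][: model.text_emb.num_embeddings]

model.load_state_dict(state, strict=True)
print(model.text_emb.weight.shape)  # torch.Size([200, 16])
```

The `batch_size` changes in `base.py` look like a normalization fix applied to the loss/accuracy denominators: dividing by a `batch_size` computed once up front rather than by `len(batch)`, which would not reflect the actual number of samples in the batch.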