From 180a4eac1bd9b3c291fa20ecf6cdae641eb4d2c3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 28 Feb 2025 01:04:24 -0600
Subject: [PATCH] ughh

---
 vall_e/models/base_v2.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 86d4960..a586103 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -376,7 +376,9 @@ class Base_V2(nn.Module):
 		elif attention_backend == "fused_attn":
 			self.l_padding = 128
 
-		if self.arch_type in ["llama"]:
+		if self.arch_type in ["none"]:
+			self.model = None
+		elif self.arch_type in ["llama"]:
 			self.model_config = LlamaConfig(
 				vocab_size=n_vocab,
 				hidden_size=d_model,
@@ -423,8 +425,10 @@ class Base_V2(nn.Module):
 		attentions = None
 		hidden_states = None
 
+		if self.arch_type in ["none"] or self.model is None:
+			...
 		# HF transformer derived model
-		if self.arch_type in ["llama"]:
+		elif self.arch_type in ["llama"]:
 			kwargs = dict(
 				inputs_embeds=x,
 				attention_mask=m,
@@ -988,7 +992,7 @@ class Base_V2(nn.Module):
 				nll, metrics = _calc_loss( logit, sequence, causal, level )
 
 				nlls.append( nll )
-				accs.append( accs )
+				accs.append( acc )
 
 			nll = sum(nlls)
 			accs = mean(accs)
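
Review note: the first two hunks add a "none" arch_type that builds no backbone (self.model = None) and, in the forward pass, falls through on a bare `...` so the input embeddings are passed along untouched; the third hunk fixes accs.append( accs ), which appended the accuracy list to itself, so that the per-level accuracy acc is recorded instead. Below is a minimal standalone sketch of the resulting control flow. It is a toy under stated assumptions: TinyBaseV2 and its arguments are illustrative stand-ins, not the repo's actual Base_V2 signature, and the llama branch is stubbed out.

    import torch
    import torch.nn as nn

    class TinyBaseV2(nn.Module):
        # Hypothetical stand-in for Base_V2's arch_type dispatch, not the real class.
        def __init__(self, arch_type="none", d_model=16):
            super().__init__()
            self.arch_type = arch_type
            if self.arch_type in ["none"]:
                # Patch behavior: no backbone module is constructed at all.
                self.model = None
            else:
                # The real class builds a HF LlamaModel here; omitted in this sketch.
                raise NotImplementedError(f"arch_type {arch_type} not wired up here")

        def forward(self, x):
            if self.arch_type in ["none"] or self.model is None:
                ...  # passthrough: x is returned as-is, mirroring the patched branch
            return x

    x = torch.randn(2, 5, 16)
    out = TinyBaseV2("none")(x)
    assert torch.equal(out, x)  # embeddings flow through unchanged

    # The third hunk's fix, in isolation: append the scalar accuracy, not the list.
    accs = []
    for acc in (0.5, 0.75):
        accs.append(acc)  # fixed; accs.append(accs) would insert the list into itself
    mean_acc = sum(accs) / len(accs)
    assert mean_acc == 0.625

Running the snippet exercises both changes: the passthrough assert confirms the "none" path leaves embeddings unchanged when no backbone exists, and the mean succeeds because the list now holds scalars rather than a self-reference.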