diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index 86d4960..a586103 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -376,7 +376,9 @@ class Base_V2(nn.Module): elif attention_backend == "fused_attn": self.l_padding = 128 - if self.arch_type in ["llama"]: + if self.arch_type in ["none"]: + self.model = None + elif self.arch_type in ["llama"]: self.model_config = LlamaConfig( vocab_size=n_vocab, hidden_size=d_model, @@ -423,8 +425,10 @@ class Base_V2(nn.Module): attentions = None hidden_states = None + if self.arch_type in ["none"] or self.model is None: + ... # HF transformer derived model - if self.arch_type in ["llama"]: + elif self.arch_type in ["llama"]: kwargs = dict( inputs_embeds=x, attention_mask=m, @@ -988,7 +992,7 @@ class Base_V2(nn.Module): nll, metrics = _calc_loss( logit, sequence, causal, level ) nlls.append( nll ) - accs.append( accs ) + accs.append( acc ) nll = sum(nlls) accs = mean(accs)