From 85b9dd47c14d62e7005d749cf153f44abaede10c Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 19 Mar 2025 13:31:50 -0500 Subject: [PATCH] ugh --- vall_e/models/base_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index 55a2eb9..e376209 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -950,7 +950,7 @@ class Base_V2(nn.Module): if causal or self.predict_causally: l = self.causal_size logit = logit[..., :-l, :] # shift the target so that token n... - token = sequence[..., l:] # ...predicts token n + 1 + token = token[..., l:] # ...predicts token n + 1 """ if self.logit_normalization: @@ -972,7 +972,7 @@ class Base_V2(nn.Module): if causal or self.predict_causally: l = self.causal_size logit = logit[..., :-l, :] # shift the target so that token n... - token = sequence[..., l:] # ...predicts token n + 1 + token = token[..., l:] # ...predicts token n + 1 """ if self.logit_normalization: