diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 954907b..32bbb16 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -485,7 +485,7 @@ class Base(nn.Module): # it would be nicer for these to be a token or live inside an embedding self.sep = nn.Parameter(torch.randn(d_model)) self.dropout_token = nn.Parameter(torch.randn(d_model)) - self.mask_token = dropout_token # alias (hopefully) to the above + self.mask_token = self.dropout_token # alias (hopefully) to the above if self.version == 1: # legacy n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom