This commit is contained in:
mrq 2024-11-07 09:19:21 -06:00
parent d13ab00ad8
commit ed174c589e

View File

@ -485,7 +485,7 @@ class Base(nn.Module):
# it would be nicer for these to be a token or live inside an embedding
self.sep = nn.Parameter(torch.randn(d_model))
self.dropout_token = nn.Parameter(torch.randn(d_model))
self.mask_token = dropout_token # alias (hopefully) to the above
self.mask_token = self.dropout_token # alias (hopefully) to the above
if self.version == 1: # legacy
n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom