nevermind thats slow

This commit is contained in:
mrq 2025-02-14 16:35:17 -06:00
parent 285e493b12
commit 13c3a08853

View File

@ -370,16 +370,12 @@ class AudioEncoder(nn.Module):
class AudioDecoder(nn.Module):
def __init__(
self,
levels,
d_model,
hidden_size,
vocab_size,
):
super().__init__()
hidden_size *= levels
vocab_size *= levels
self.vocab_size = vocab_size
self.up = nn.Linear( d_model, hidden_size )
self.down = nn.Linear( hidden_size, vocab_size )
@ -715,8 +711,6 @@ class Base(nn.Module):
self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
self.monolithic_audio_encoder = False # monolithic sounds bad
if self.version >= 7:
dec_dim = d_model * 4
if self.monolithic_audio_encoder:
self.audio_emb = AudioEncoder(
n_tokens=n_audio_tokens + 1, # masked token
@ -736,10 +730,9 @@ class Base(nn.Module):
)
self.audio_decoder = AudioDecoder(
self.n_resp_levels,
d_model,
dec_dim,
n_audio_tokens + 1,
d_model * 2,
(n_audio_tokens + 1) * self.n_resp_levels,
)
if attention_backend == "auto":