nevermind thats slow
This commit is contained in:
parent
285e493b12
commit
13c3a08853
|
@ -370,16 +370,12 @@ class AudioEncoder(nn.Module):
|
|||
class AudioDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
levels,
|
||||
d_model,
|
||||
hidden_size,
|
||||
vocab_size,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
hidden_size *= levels
|
||||
vocab_size *= levels
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.up = nn.Linear( d_model, hidden_size )
|
||||
self.down = nn.Linear( hidden_size, vocab_size )
|
||||
|
@ -715,8 +711,6 @@ class Base(nn.Module):
|
|||
self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
|
||||
self.monolithic_audio_encoder = False # monolithic sounds bad
|
||||
if self.version >= 7:
|
||||
dec_dim = d_model * 4
|
||||
|
||||
if self.monolithic_audio_encoder:
|
||||
self.audio_emb = AudioEncoder(
|
||||
n_tokens=n_audio_tokens + 1, # masked token
|
||||
|
@ -736,10 +730,9 @@ class Base(nn.Module):
|
|||
)
|
||||
|
||||
self.audio_decoder = AudioDecoder(
|
||||
self.n_resp_levels,
|
||||
d_model,
|
||||
dec_dim,
|
||||
n_audio_tokens + 1,
|
||||
d_model * 2,
|
||||
(n_audio_tokens + 1) * self.n_resp_levels,
|
||||
)
|
||||
|
||||
if attention_backend == "auto":
|
||||
|
|
Loading…
Reference in New Issue
Block a user