This commit is contained in:
James Betker 2022-04-02 21:57:00 -06:00
parent 4c6bdfc9e2
commit b6afc4d542

View File

@ -205,6 +205,7 @@ class AutoregressiveCodegen(nn.Module):
self.encoder = CheckpointedXTransformerWrapper(
num_tokens=num_text_tokens,
use_pos_emb=False,
max_seq_len=-1,
attn_layers = Encoder(
depth=depth//2,
heads=model_dim//64,
@ -220,6 +221,7 @@ class AutoregressiveCodegen(nn.Module):
self.decoder = CheckpointedXTransformerWrapper(
num_tokens=num_mel_tokens,
use_pos_emb=False,
max_seq_len=-1,
attn_layers=Decoder(
depth=depth,
heads=model_dim//64,