diff --git a/codes/models/audio/tts/autoregressive_codegen.py b/codes/models/audio/tts/autoregressive_codegen.py index b820879b..22403233 100644 --- a/codes/models/audio/tts/autoregressive_codegen.py +++ b/codes/models/audio/tts/autoregressive_codegen.py @@ -159,6 +159,7 @@ class ConditioningEncoder(nn.Module): class AutoregressiveCodegen(nn.Module): def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1): super().__init__() + assert depth >= 8 # This is the minimum bound to support the context interleaving that happens later. self.START_TOKEN=8192 self.STOP_TOKEN=8193 @@ -170,7 +171,7 @@ class AutoregressiveCodegen(nn.Module): use_pos_emb=False, max_seq_len=-1, attn_layers = Encoder( - depth=depth//2, + depth=depth, heads=model_dim//64, dim=model_dim, attn_dropout=dropout, @@ -181,7 +182,8 @@ class AutoregressiveCodegen(nn.Module): rotary_pos_emb=True, attn_rel_pos_bias=True, )) - self.encoder.to_logits = nn.Identity() # This is unused. + self.encoder.norm = nn.Identity() # This layer and the next are unused. + self.encoder.to_logits = nn.Identity() self.decoder = TransformerWrapper( num_tokens=num_mel_tokens, use_pos_emb=False, @@ -224,12 +226,16 @@ class AutoregressiveCodegen(nn.Module): for i in range(conditioning_signal.shape[1]): cond_embs.append(self.mel_embedding(conditioning_signal[:, i])) cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True) - enc_text = self.encoder(text_codes, return_embeddings=True) - context = torch.cat([cond_emb, enc_text], dim=1) + _, enc_text = self.encoder(text_codes, return_hiddens=True) + # Interleave cond_emb into the first few contexts. + full_context = enc_text + full_context[1] = cond_emb + full_context[3] = cond_emb + full_context[6] = cond_emb # Execute the decoder dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1] - dec = self.decoder(dec_inputs, context=context) + dec = self.decoder(dec_inputs, full_context=full_context) if not return_loss: return dec loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes) @@ -261,10 +267,10 @@ def register_autoregressive_codegen(opt_net, opt): if __name__ == '__main__': - codegen = AutoregressiveCodegen(512, 20) + codegen = AutoregressiveCodegen(256, 10) torch.save(codegen.state_dict(), 'sample.pth') - codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200))) + #codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200))) codegen(torch.randint(0,256, (2,200)), torch.randn(2,80,120), torch.randint(0,8192, (2,350)), - torch.tensor([192,350])) \ No newline at end of file + torch.tensor([192,350])) diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py index 453de22a..d2b22d66 100644 --- a/codes/models/lucidrains/x_transformers.py +++ b/codes/models/lucidrains/x_transformers.py @@ -841,13 +841,16 @@ class AttentionLayers(nn.Module): self, x, context = None, + full_context = None, # for passing a list of hidden states from an encoder mask = None, context_mask = None, attn_mask = None, mems = None, return_hiddens = False ): - assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True' + + assert not (self.cross_attend ^ (exists(context) or exists(full_context))), 'context must be passed in if cross_attend is set to True' + assert context is None or full_context is None, 'only one of full_context or context can be provided' hiddens = [] intermediates = [] @@ -861,9 +864,9 @@ class AttentionLayers(nn.Module): max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems))) rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) + cross_attn_count = 0 for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): if layer_type == 'a': - hiddens.append(x) layer_mem = mems.pop(0) if mems else None residual = x @@ -876,7 +879,10 @@ class AttentionLayers(nn.Module): if layer_type == 'a': out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem) elif layer_type == 'c': - out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn) + if exists(full_context): + out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None, None, prev_attn) + else: + out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn) elif layer_type == 'f': out = checkpoint(block, x) @@ -896,6 +902,12 @@ class AttentionLayers(nn.Module): if exists(post_main_norm): x = post_main_norm(x) + if layer_type == 'c': + cross_attn_count += 1 + + if layer_type == 'f': + hiddens.append(x) + if return_hiddens: intermediates = LayerIntermediates( hiddens = hiddens, @@ -1024,7 +1036,7 @@ class TransformerWrapper(nn.Module): x, return_embeddings = False, mask = None, - return_mems = False, + return_hiddens = False, return_attn = False, mems = None, **kwargs @@ -1055,11 +1067,9 @@ class TransformerWrapper(nn.Module): out = self.to_logits(x) if not return_embeddings else x - if return_mems: + if return_hiddens: hiddens = intermediates.hiddens - new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens - new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) - return out, new_mems + return out, hiddens if return_attn: attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))