autoregressive_codegen r3

This commit is contained in:
James Betker 2022-04-06 21:04:23 -06:00
parent 33ef17e9e5
commit e011166dd6
2 changed files with 32 additions and 16 deletions

View File

@ -159,6 +159,7 @@ class ConditioningEncoder(nn.Module):
class AutoregressiveCodegen(nn.Module):
def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
super().__init__()
assert depth >= 8 # This is the minimum bound to support the context interleaving that happens later.
self.START_TOKEN=8192
self.STOP_TOKEN=8193
@ -170,7 +171,7 @@ class AutoregressiveCodegen(nn.Module):
use_pos_emb=False,
max_seq_len=-1,
attn_layers = Encoder(
depth=depth//2,
depth=depth,
heads=model_dim//64,
dim=model_dim,
attn_dropout=dropout,
@ -181,7 +182,8 @@ class AutoregressiveCodegen(nn.Module):
rotary_pos_emb=True,
attn_rel_pos_bias=True,
))
self.encoder.to_logits = nn.Identity() # This is unused.
self.encoder.norm = nn.Identity() # This layer and the next are unused.
self.encoder.to_logits = nn.Identity()
self.decoder = TransformerWrapper(
num_tokens=num_mel_tokens,
use_pos_emb=False,
@ -224,12 +226,16 @@ class AutoregressiveCodegen(nn.Module):
for i in range(conditioning_signal.shape[1]):
cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
enc_text = self.encoder(text_codes, return_embeddings=True)
context = torch.cat([cond_emb, enc_text], dim=1)
_, enc_text = self.encoder(text_codes, return_hiddens=True)
# Interleave cond_emb into the first few contexts.
full_context = enc_text
full_context[1] = cond_emb
full_context[3] = cond_emb
full_context[6] = cond_emb
# Execute the decoder
dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
dec = self.decoder(dec_inputs, context=context)
dec = self.decoder(dec_inputs, full_context=full_context)
if not return_loss:
return dec
loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
@ -261,9 +267,9 @@ def register_autoregressive_codegen(opt_net, opt):
if __name__ == '__main__':
codegen = AutoregressiveCodegen(512, 20)
codegen = AutoregressiveCodegen(256, 10)
torch.save(codegen.state_dict(), 'sample.pth')
codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
#codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
codegen(torch.randint(0,256, (2,200)),
torch.randn(2,80,120),
torch.randint(0,8192, (2,350)),

View File

@ -841,13 +841,16 @@ class AttentionLayers(nn.Module):
self,
x,
context = None,
full_context = None, # for passing a list of hidden states from an encoder
mask = None,
context_mask = None,
attn_mask = None,
mems = None,
return_hiddens = False
):
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
assert not (self.cross_attend ^ (exists(context) or exists(full_context))), 'context must be passed in if cross_attend is set to True'
assert context is None or full_context is None, 'only one of full_context or context can be provided'
hiddens = []
intermediates = []
@ -861,9 +864,9 @@ class AttentionLayers(nn.Module):
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
cross_attn_count = 0
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
if layer_type == 'a':
hiddens.append(x)
layer_mem = mems.pop(0) if mems else None
residual = x
@ -876,7 +879,10 @@ class AttentionLayers(nn.Module):
if layer_type == 'a':
out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
elif layer_type == 'c':
out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
if exists(full_context):
out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None, None, prev_attn)
else:
out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
elif layer_type == 'f':
out = checkpoint(block, x)
@ -896,6 +902,12 @@ class AttentionLayers(nn.Module):
if exists(post_main_norm):
x = post_main_norm(x)
if layer_type == 'c':
cross_attn_count += 1
if layer_type == 'f':
hiddens.append(x)
if return_hiddens:
intermediates = LayerIntermediates(
hiddens = hiddens,
@ -1024,7 +1036,7 @@ class TransformerWrapper(nn.Module):
x,
return_embeddings = False,
mask = None,
return_mems = False,
return_hiddens = False,
return_attn = False,
mems = None,
**kwargs
@ -1055,11 +1067,9 @@ class TransformerWrapper(nn.Module):
out = self.to_logits(x) if not return_embeddings else x
if return_mems:
if return_hiddens:
hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
return out, new_mems
return out, hiddens
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))