autoregressive_codegen r3
This commit is contained in:
parent
33ef17e9e5
commit
e011166dd6
|
@ -159,6 +159,7 @@ class ConditioningEncoder(nn.Module):
|
|||
class AutoregressiveCodegen(nn.Module):
|
||||
def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
|
||||
super().__init__()
|
||||
assert depth >= 8 # This is the minimum bound to support the context interleaving that happens later.
|
||||
|
||||
self.START_TOKEN=8192
|
||||
self.STOP_TOKEN=8193
|
||||
|
@ -170,7 +171,7 @@ class AutoregressiveCodegen(nn.Module):
|
|||
use_pos_emb=False,
|
||||
max_seq_len=-1,
|
||||
attn_layers = Encoder(
|
||||
depth=depth//2,
|
||||
depth=depth,
|
||||
heads=model_dim//64,
|
||||
dim=model_dim,
|
||||
attn_dropout=dropout,
|
||||
|
@ -181,7 +182,8 @@ class AutoregressiveCodegen(nn.Module):
|
|||
rotary_pos_emb=True,
|
||||
attn_rel_pos_bias=True,
|
||||
))
|
||||
self.encoder.to_logits = nn.Identity() # This is unused.
|
||||
self.encoder.norm = nn.Identity() # This layer and the next are unused.
|
||||
self.encoder.to_logits = nn.Identity()
|
||||
self.decoder = TransformerWrapper(
|
||||
num_tokens=num_mel_tokens,
|
||||
use_pos_emb=False,
|
||||
|
@ -224,12 +226,16 @@ class AutoregressiveCodegen(nn.Module):
|
|||
for i in range(conditioning_signal.shape[1]):
|
||||
cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
|
||||
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
|
||||
enc_text = self.encoder(text_codes, return_embeddings=True)
|
||||
context = torch.cat([cond_emb, enc_text], dim=1)
|
||||
_, enc_text = self.encoder(text_codes, return_hiddens=True)
|
||||
# Interleave cond_emb into the first few contexts.
|
||||
full_context = enc_text
|
||||
full_context[1] = cond_emb
|
||||
full_context[3] = cond_emb
|
||||
full_context[6] = cond_emb
|
||||
|
||||
# Execute the decoder
|
||||
dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
|
||||
dec = self.decoder(dec_inputs, context=context)
|
||||
dec = self.decoder(dec_inputs, full_context=full_context)
|
||||
if not return_loss:
|
||||
return dec
|
||||
loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
|
||||
|
@ -261,10 +267,10 @@ def register_autoregressive_codegen(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
codegen = AutoregressiveCodegen(512, 20)
|
||||
codegen = AutoregressiveCodegen(256, 10)
|
||||
torch.save(codegen.state_dict(), 'sample.pth')
|
||||
codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
|
||||
#codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
|
||||
codegen(torch.randint(0,256, (2,200)),
|
||||
torch.randn(2,80,120),
|
||||
torch.randint(0,8192, (2,350)),
|
||||
torch.tensor([192,350]))
|
||||
torch.tensor([192,350]))
|
||||
|
|
|
@ -841,13 +841,16 @@ class AttentionLayers(nn.Module):
|
|||
self,
|
||||
x,
|
||||
context = None,
|
||||
full_context = None, # for passing a list of hidden states from an encoder
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
attn_mask = None,
|
||||
mems = None,
|
||||
return_hiddens = False
|
||||
):
|
||||
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
|
||||
|
||||
assert not (self.cross_attend ^ (exists(context) or exists(full_context))), 'context must be passed in if cross_attend is set to True'
|
||||
assert context is None or full_context is None, 'only one of full_context or context can be provided'
|
||||
|
||||
hiddens = []
|
||||
intermediates = []
|
||||
|
@ -861,9 +864,9 @@ class AttentionLayers(nn.Module):
|
|||
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
|
||||
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
|
||||
|
||||
cross_attn_count = 0
|
||||
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
|
||||
if layer_type == 'a':
|
||||
hiddens.append(x)
|
||||
layer_mem = mems.pop(0) if mems else None
|
||||
|
||||
residual = x
|
||||
|
@ -876,7 +879,10 @@ class AttentionLayers(nn.Module):
|
|||
if layer_type == 'a':
|
||||
out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
|
||||
elif layer_type == 'c':
|
||||
out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
|
||||
if exists(full_context):
|
||||
out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None, None, prev_attn)
|
||||
else:
|
||||
out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
|
||||
elif layer_type == 'f':
|
||||
out = checkpoint(block, x)
|
||||
|
||||
|
@ -896,6 +902,12 @@ class AttentionLayers(nn.Module):
|
|||
if exists(post_main_norm):
|
||||
x = post_main_norm(x)
|
||||
|
||||
if layer_type == 'c':
|
||||
cross_attn_count += 1
|
||||
|
||||
if layer_type == 'f':
|
||||
hiddens.append(x)
|
||||
|
||||
if return_hiddens:
|
||||
intermediates = LayerIntermediates(
|
||||
hiddens = hiddens,
|
||||
|
@ -1024,7 +1036,7 @@ class TransformerWrapper(nn.Module):
|
|||
x,
|
||||
return_embeddings = False,
|
||||
mask = None,
|
||||
return_mems = False,
|
||||
return_hiddens = False,
|
||||
return_attn = False,
|
||||
mems = None,
|
||||
**kwargs
|
||||
|
@ -1055,11 +1067,9 @@ class TransformerWrapper(nn.Module):
|
|||
|
||||
out = self.to_logits(x) if not return_embeddings else x
|
||||
|
||||
if return_mems:
|
||||
if return_hiddens:
|
||||
hiddens = intermediates.hiddens
|
||||
new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
|
||||
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
|
||||
return out, new_mems
|
||||
return out, hiddens
|
||||
|
||||
if return_attn:
|
||||
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
||||
|
|
Loading…
Reference in New Issue
Block a user