autoregressive_codegen r3

parent 33ef17e9e5
commit e011166dd6
@@ -159,6 +159,7 @@ class ConditioningEncoder(nn.Module):
 class AutoregressiveCodegen(nn.Module):
     def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
         super().__init__()
+        assert depth >= 8 # This is the minimum bound to support the context interleaving that happens later.
 
         self.START_TOKEN=8192
         self.STOP_TOKEN=8193
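The bound in the new assert is tied to the interleaving done later in forward() (the @@ -224,12 +226,16 @@ hunk below): the conditioning embedding is written into encoder hidden states 1, 3 and 6, and the patched AttentionLayers collects one hidden state per layer, so at least 7 layers are needed for those indices to exist. A tiny illustrative check, not part of the commit:

    # Indices overwritten with the conditioning embedding in forward().
    interleave_indices = (1, 3, 6)
    # One hidden state per encoder layer => the largest index sets the floor.
    min_depth = max(interleave_indices) + 1   # 7; the commit asserts the slightly stricter depth >= 8
    assert 8 >= min_depth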
@@ -170,7 +171,7 @@ class AutoregressiveCodegen(nn.Module):
             use_pos_emb=False,
             max_seq_len=-1,
             attn_layers = Encoder(
-                depth=depth//2,
+                depth=depth,
                 heads=model_dim//64,
                 dim=model_dim,
                 attn_dropout=dropout,
@@ -181,7 +182,8 @@ class AutoregressiveCodegen(nn.Module):
                 rotary_pos_emb=True,
                 attn_rel_pos_bias=True,
             ))
-        self.encoder.to_logits = nn.Identity() # This is unused.
+        self.encoder.norm = nn.Identity() # This layer and the next are unused.
+        self.encoder.to_logits = nn.Identity()
         self.decoder = TransformerWrapper(
             num_tokens=num_mel_tokens,
             use_pos_emb=False,
@@ -224,12 +226,16 @@ class AutoregressiveCodegen(nn.Module):
         for i in range(conditioning_signal.shape[1]):
             cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
         cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
-        enc_text = self.encoder(text_codes, return_embeddings=True)
-        context = torch.cat([cond_emb, enc_text], dim=1)
+        _, enc_text = self.encoder(text_codes, return_hiddens=True)
+        # Interleave cond_emb into the first few contexts.
+        full_context = enc_text
+        full_context[1] = cond_emb
+        full_context[3] = cond_emb
+        full_context[6] = cond_emb
 
         # Execute the decoder
         dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
-        dec = self.decoder(dec_inputs, context=context)
+        dec = self.decoder(dec_inputs, full_context=full_context)
         if not return_loss:
             return dec
         loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
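For reference, a minimal sketch of the interleaving above, using stand-in tensors instead of the module's real embeddings (batch, length and width values are assumptions): the encoder now hands back one hidden state per layer, and a few of those entries are replaced wholesale by the averaged conditioning embedding before the list goes to the decoder's cross-attention layers.

    import torch

    b, text_len, dim, n_layers = 2, 200, 512, 10          # assumed shapes, for illustration only
    enc_text = [torch.randn(b, text_len, dim) for _ in range(n_layers)]  # per-layer encoder hidden states
    cond_emb = torch.randn(b, 1, dim)                      # averaged conditioning embedding, (b, 1, dim)

    full_context = enc_text                                # same aliasing as the commit: the list is mutated in place
    for idx in (1, 3, 6):
        full_context[idx] = cond_emb                       # these cross-attention layers see only the conditioning
    # full_context[k] is then consumed by the k-th cross-attention layer of the decoder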
@@ -261,10 +267,10 @@ def register_autoregressive_codegen(opt_net, opt):
 
 
 if __name__ == '__main__':
-    codegen = AutoregressiveCodegen(512, 20)
+    codegen = AutoregressiveCodegen(256, 10)
     torch.save(codegen.state_dict(), 'sample.pth')
-    codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
+    #codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
     codegen(torch.randint(0,256, (2,200)),
             torch.randn(2,80,120),
             torch.randint(0,8192, (2,350)),
             torch.tensor([192,350]))
@@ -841,13 +841,16 @@ class AttentionLayers(nn.Module):
         self,
         x,
         context = None,
+        full_context = None, # for passing a list of hidden states from an encoder
         mask = None,
         context_mask = None,
         attn_mask = None,
         mems = None,
         return_hiddens = False
     ):
-        assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
+        assert not (self.cross_attend ^ (exists(context) or exists(full_context))), 'context must be passed in if cross_attend is set to True'
+        assert context is None or full_context is None, 'only one of full_context or context can be provided'
 
         hiddens = []
         intermediates = []
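A hypothetical call sketch for the new argument (the import path and the Decoder class name are assumptions carried over from upstream x_transformers; only the vendored copy patched here actually accepts full_context): full_context supplies one context tensor per cross-attention layer and is mutually exclusive with the single shared context.

    import torch
    from x_transformers import Decoder   # stand-in import; this repo patches its own vendored copy

    dec = Decoder(dim=512, depth=4, heads=8, cross_attend=True)   # 4 cross-attention layers
    x = torch.randn(2, 350, 512)
    contexts = [torch.randn(2, 200, 512) for _ in range(4)]       # one context per 'c' layer

    out = dec(x, full_context=contexts)      # new path: a different context for every cross-attention layer
    out = dec(x, context=contexts[0])        # existing path: one context shared by all of them
    # dec(x, context=contexts[0], full_context=contexts) would trip the new mutual-exclusion assert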
@@ -861,9 +864,9 @@ class AttentionLayers(nn.Module):
             max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
 
+        cross_attn_count = 0
         for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
             if layer_type == 'a':
-                hiddens.append(x)
                 layer_mem = mems.pop(0) if mems else None
 
             residual = x
@@ -876,7 +879,10 @@ class AttentionLayers(nn.Module):
             if layer_type == 'a':
                 out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem)
             elif layer_type == 'c':
-                out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
+                if exists(full_context):
+                    out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None, None, prev_attn)
+                else:
+                    out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
             elif layer_type == 'f':
                 out = checkpoint(block, x)
 
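A standalone sketch of the routing this branch implements, in plain Python with no model objects (the layer pattern is an assumption; x_transformers decoders typically cycle attention, cross-attention, feedforward): the k-th cross-attention layer reads full_context[k], with the counter advanced in the next hunk.

    layer_types = ('a', 'c', 'f') * 3                 # assumed decoder layer pattern
    full_context = ['ctx_0', 'ctx_1', 'ctx_2']        # stand-ins for per-layer context tensors

    cross_attn_count = 0
    for layer_type in layer_types:
        if layer_type == 'c':
            print(f"cross-attention layer {cross_attn_count} attends to {full_context[cross_attn_count]}")
        # ... attention / feedforward work happens here ...
        if layer_type == 'c':
            cross_attn_count += 1                     # advanced after the block, as in the @@ -896,6 +902,12 @@ hunk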
@@ -896,6 +902,12 @@ class AttentionLayers(nn.Module):
             if exists(post_main_norm):
                 x = post_main_norm(x)
 
+            if layer_type == 'c':
+                cross_attn_count += 1
+
+            if layer_type == 'f':
+                hiddens.append(x)
+
         if return_hiddens:
             intermediates = LayerIntermediates(
                 hiddens = hiddens,
@@ -1024,7 +1036,7 @@ class TransformerWrapper(nn.Module):
         x,
         return_embeddings = False,
         mask = None,
-        return_mems = False,
+        return_hiddens = False,
         return_attn = False,
         mems = None,
         **kwargs
@@ -1055,11 +1067,9 @@ class TransformerWrapper(nn.Module):
 
         out = self.to_logits(x) if not return_embeddings else x
 
-        if return_mems:
+        if return_hiddens:
             hiddens = intermediates.hiddens
-            new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
-            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
-            return out, new_mems
+            return out, hiddens
 
         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
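With this change the wrapper's return_hiddens path replaces the old return_mems behaviour: instead of concatenating and truncating memories it returns the raw list of per-layer hidden states alongside the logits, which is what AutoregressiveCodegen consumes as full_context. A minimal usage sketch, assuming the patched copy is importable as x_transformers (the import path is an assumption):

    import torch
    from x_transformers import TransformerWrapper, Encoder   # stand-in import for the vendored copy

    enc = TransformerWrapper(
        num_tokens=256,
        use_pos_emb=False,
        max_seq_len=-1,
        attn_layers=Encoder(dim=512, depth=8, heads=8))
    tokens = torch.randint(0, 256, (2, 200))

    logits, hiddens = enc(tokens, return_hiddens=True)
    # hiddens holds one tensor per feedforward block: 8 entries of shape (2, 200, 512)
    assert len(hiddens) == 8 and hiddens[0].shape == (2, 200, 512)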