This commit is contained in:
James Betker 2022-05-26 20:19:09 -06:00
parent dd13b883ac
commit 8587a18717
2 changed files with 26 additions and 13 deletions

View File

@ -19,6 +19,16 @@ def is_sequence(t):
return t.dtype == torch.long
class MultiGroupEmbedding(nn.Module):
def __init__(self, tokens, groups, dim):
super().__init__()
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
def forward(self, x):
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
return torch.cat(h, dim=-1)
class TimestepResBlock(TimestepBlock):
def __init__(
self,
@ -114,8 +124,8 @@ class FlatDiffusion(nn.Module):
num_layers=8,
in_channels=256,
in_latent_channels=512,
in_vectors=8,
in_groups=8,
token_count=8,
in_groups=None,
out_channels=512, # mean and variance
dropout=0,
use_fp16=False,
@ -147,7 +157,10 @@ class FlatDiffusion(nn.Module):
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
self.embeddings = nn.ModuleList([nn.Embedding(in_vectors, model_channels//in_groups) for _ in range(in_groups)])
if in_groups is None:
self.embeddings = nn.Embedding(token_count, model_channels)
else:
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
self.latent_conditioner = nn.Sequential(
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
@ -222,8 +235,8 @@ class FlatDiffusion(nn.Module):
if is_latent(aligned_conditioning):
code_emb = self.latent_conditioner(aligned_conditioning)
else:
code_emb = [embedding(aligned_conditioning[:, :, i]) for i, embedding in enumerate(self.embeddings)]
code_emb = torch.cat(code_emb, dim=-1).permute(0,2,1)
code_emb = self.embeddings(aligned_conditioning)
code_emb = code_emb.permute(0,2,1)
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
@ -245,19 +258,19 @@ class FlatDiffusion(nn.Module):
return expanded_code_emb, mel_pred
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
:param codes: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
:return: an [N x C x ...] Tensor of outputs.
"""
assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
assert precomputed_aligned_embeddings is not None or (codes is not None and conditioning_input is not None)
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
unused_params = []
@ -269,8 +282,8 @@ class FlatDiffusion(nn.Module):
if precomputed_aligned_embeddings is not None:
code_emb = precomputed_aligned_embeddings
else:
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
if is_latent(aligned_conditioning):
code_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], True)
if is_latent(codes):
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
else:
unused_params.extend(list(self.latent_conditioner.parameters()))
@ -324,7 +337,7 @@ if __name__ == '__main__':
aligned_sequence = torch.randint(0,8,(2,100,8))
cond = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, train_mel_head=True)
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, in_groups=8, train_mel_head=True)
# Test with latent aligned conditioning
#o = model(clip, ts, aligned_latent, cond)
# Test with sequence aligned conditioning

View File

@ -138,7 +138,7 @@ class TransformerDiffusion(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
self.rotary_embeddings = None if rotary_emb_dim is None else RotaryEmbedding(rotary_emb_dim)
self.layers = SpecialSequential(*[AttentionBlock(model_channels, model_channels // 64, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(
@ -217,7 +217,7 @@ class TransformerDiffusion(nn.Module):
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb
x = self.inp_block(x).permute(0,2,1)
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
rotary_pos_emb = None if self.rotary_embeddings is None else self.rotary_embeddings(x.shape[1], x.device)
x = self.layers(x, code_emb, blk_emb, rotary_pos_emb)
x = x.float().permute(0,2,1)