fd fix
This commit is contained in:
parent
dd13b883ac
commit
8587a18717
|
@ -19,6 +19,16 @@ def is_sequence(t):
|
||||||
return t.dtype == torch.long
|
return t.dtype == torch.long
|
||||||
|
|
||||||
|
|
||||||
|
class MultiGroupEmbedding(nn.Module):
|
||||||
|
def __init__(self, tokens, groups, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
||||||
|
return torch.cat(h, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
class TimestepResBlock(TimestepBlock):
|
class TimestepResBlock(TimestepBlock):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -114,8 +124,8 @@ class FlatDiffusion(nn.Module):
|
||||||
num_layers=8,
|
num_layers=8,
|
||||||
in_channels=256,
|
in_channels=256,
|
||||||
in_latent_channels=512,
|
in_latent_channels=512,
|
||||||
in_vectors=8,
|
token_count=8,
|
||||||
in_groups=8,
|
in_groups=None,
|
||||||
out_channels=512, # mean and variance
|
out_channels=512, # mean and variance
|
||||||
dropout=0,
|
dropout=0,
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
|
@ -147,7 +157,10 @@ class FlatDiffusion(nn.Module):
|
||||||
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
||||||
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
|
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
|
||||||
# transformer network.
|
# transformer network.
|
||||||
self.embeddings = nn.ModuleList([nn.Embedding(in_vectors, model_channels//in_groups) for _ in range(in_groups)])
|
if in_groups is None:
|
||||||
|
self.embeddings = nn.Embedding(token_count, model_channels)
|
||||||
|
else:
|
||||||
|
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
|
||||||
self.latent_conditioner = nn.Sequential(
|
self.latent_conditioner = nn.Sequential(
|
||||||
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
|
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
|
||||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||||
|
@ -222,8 +235,8 @@ class FlatDiffusion(nn.Module):
|
||||||
if is_latent(aligned_conditioning):
|
if is_latent(aligned_conditioning):
|
||||||
code_emb = self.latent_conditioner(aligned_conditioning)
|
code_emb = self.latent_conditioner(aligned_conditioning)
|
||||||
else:
|
else:
|
||||||
code_emb = [embedding(aligned_conditioning[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
code_emb = self.embeddings(aligned_conditioning)
|
||||||
code_emb = torch.cat(code_emb, dim=-1).permute(0,2,1)
|
code_emb = code_emb.permute(0,2,1)
|
||||||
|
|
||||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||||
|
@ -245,19 +258,19 @@ class FlatDiffusion(nn.Module):
|
||||||
return expanded_code_emb, mel_pred
|
return expanded_code_emb, mel_pred
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||||
"""
|
"""
|
||||||
Apply the model to an input batch.
|
Apply the model to an input batch.
|
||||||
|
|
||||||
:param x: an [N x C x ...] Tensor of inputs.
|
:param x: an [N x C x ...] Tensor of inputs.
|
||||||
:param timesteps: a 1-D batch of timesteps.
|
:param timesteps: a 1-D batch of timesteps.
|
||||||
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
:param codes: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
||||||
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
|
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
|
||||||
:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
|
:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
|
||||||
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
||||||
:return: an [N x C x ...] Tensor of outputs.
|
:return: an [N x C x ...] Tensor of outputs.
|
||||||
"""
|
"""
|
||||||
assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
|
assert precomputed_aligned_embeddings is not None or (codes is not None and conditioning_input is not None)
|
||||||
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
||||||
|
|
||||||
unused_params = []
|
unused_params = []
|
||||||
|
@ -269,8 +282,8 @@ class FlatDiffusion(nn.Module):
|
||||||
if precomputed_aligned_embeddings is not None:
|
if precomputed_aligned_embeddings is not None:
|
||||||
code_emb = precomputed_aligned_embeddings
|
code_emb = precomputed_aligned_embeddings
|
||||||
else:
|
else:
|
||||||
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
|
code_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], True)
|
||||||
if is_latent(aligned_conditioning):
|
if is_latent(codes):
|
||||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||||
else:
|
else:
|
||||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||||
|
@ -324,7 +337,7 @@ if __name__ == '__main__':
|
||||||
aligned_sequence = torch.randint(0,8,(2,100,8))
|
aligned_sequence = torch.randint(0,8,(2,100,8))
|
||||||
cond = torch.randn(2, 256, 400)
|
cond = torch.randn(2, 256, 400)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, train_mel_head=True)
|
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, in_groups=8, train_mel_head=True)
|
||||||
# Test with latent aligned conditioning
|
# Test with latent aligned conditioning
|
||||||
#o = model(clip, ts, aligned_latent, cond)
|
#o = model(clip, ts, aligned_latent, cond)
|
||||||
# Test with sequence aligned conditioning
|
# Test with sequence aligned conditioning
|
||||||
|
|
|
@ -138,7 +138,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
||||||
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
|
self.rotary_embeddings = None if rotary_emb_dim is None else RotaryEmbedding(rotary_emb_dim)
|
||||||
self.layers = SpecialSequential(*[AttentionBlock(model_channels, model_channels // 64, dropout) for _ in range(num_layers)])
|
self.layers = SpecialSequential(*[AttentionBlock(model_channels, model_channels // 64, dropout) for _ in range(num_layers)])
|
||||||
|
|
||||||
self.out = nn.Sequential(
|
self.out = nn.Sequential(
|
||||||
|
@ -217,7 +217,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb
|
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb
|
||||||
x = self.inp_block(x).permute(0,2,1)
|
x = self.inp_block(x).permute(0,2,1)
|
||||||
|
|
||||||
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
rotary_pos_emb = None if self.rotary_embeddings is None else self.rotary_embeddings(x.shape[1], x.device)
|
||||||
x = self.layers(x, code_emb, blk_emb, rotary_pos_emb)
|
x = self.layers(x, code_emb, blk_emb, rotary_pos_emb)
|
||||||
|
|
||||||
x = x.float().permute(0,2,1)
|
x = x.float().permute(0,2,1)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user