forked from mrq/DL-Art-School
fd fix
This commit is contained in:
parent
dd13b883ac
commit
8587a18717
|
@ -19,6 +19,16 @@ def is_sequence(t):
|
|||
return t.dtype == torch.long
|
||||
|
||||
|
||||
class MultiGroupEmbedding(nn.Module):
|
||||
def __init__(self, tokens, groups, dim):
|
||||
super().__init__()
|
||||
self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)])
|
||||
|
||||
def forward(self, x):
|
||||
h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)]
|
||||
return torch.cat(h, dim=-1)
|
||||
|
||||
|
||||
class TimestepResBlock(TimestepBlock):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -114,8 +124,8 @@ class FlatDiffusion(nn.Module):
|
|||
num_layers=8,
|
||||
in_channels=256,
|
||||
in_latent_channels=512,
|
||||
in_vectors=8,
|
||||
in_groups=8,
|
||||
token_count=8,
|
||||
in_groups=None,
|
||||
out_channels=512, # mean and variance
|
||||
dropout=0,
|
||||
use_fp16=False,
|
||||
|
@ -147,7 +157,10 @@ class FlatDiffusion(nn.Module):
|
|||
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
||||
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
|
||||
# transformer network.
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(in_vectors, model_channels//in_groups) for _ in range(in_groups)])
|
||||
if in_groups is None:
|
||||
self.embeddings = nn.Embedding(token_count, model_channels)
|
||||
else:
|
||||
self.embeddings = MultiGroupEmbedding(token_count, in_groups, model_channels)
|
||||
self.latent_conditioner = nn.Sequential(
|
||||
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
|
@ -222,8 +235,8 @@ class FlatDiffusion(nn.Module):
|
|||
if is_latent(aligned_conditioning):
|
||||
code_emb = self.latent_conditioner(aligned_conditioning)
|
||||
else:
|
||||
code_emb = [embedding(aligned_conditioning[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
||||
code_emb = torch.cat(code_emb, dim=-1).permute(0,2,1)
|
||||
code_emb = self.embeddings(aligned_conditioning)
|
||||
code_emb = code_emb.permute(0,2,1)
|
||||
|
||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
|
@ -245,19 +258,19 @@ class FlatDiffusion(nn.Module):
|
|||
return expanded_code_emb, mel_pred
|
||||
|
||||
|
||||
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||
def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
||||
:param codes: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
||||
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
|
||||
:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
|
||||
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
|
||||
assert precomputed_aligned_embeddings is not None or (codes is not None and conditioning_input is not None)
|
||||
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
||||
|
||||
unused_params = []
|
||||
|
@ -269,8 +282,8 @@ class FlatDiffusion(nn.Module):
|
|||
if precomputed_aligned_embeddings is not None:
|
||||
code_emb = precomputed_aligned_embeddings
|
||||
else:
|
||||
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
|
||||
if is_latent(aligned_conditioning):
|
||||
code_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], True)
|
||||
if is_latent(codes):
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
else:
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
|
@ -324,7 +337,7 @@ if __name__ == '__main__':
|
|||
aligned_sequence = torch.randint(0,8,(2,100,8))
|
||||
cond = torch.randn(2, 256, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, train_mel_head=True)
|
||||
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, in_groups=8, train_mel_head=True)
|
||||
# Test with latent aligned conditioning
|
||||
#o = model(clip, ts, aligned_latent, cond)
|
||||
# Test with sequence aligned conditioning
|
||||
|
|
|
@ -138,7 +138,7 @@ class TransformerDiffusion(nn.Module):
|
|||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
||||
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim)
|
||||
self.rotary_embeddings = None if rotary_emb_dim is None else RotaryEmbedding(rotary_emb_dim)
|
||||
self.layers = SpecialSequential(*[AttentionBlock(model_channels, model_channels // 64, dropout) for _ in range(num_layers)])
|
||||
|
||||
self.out = nn.Sequential(
|
||||
|
@ -217,7 +217,7 @@ class TransformerDiffusion(nn.Module):
|
|||
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb
|
||||
x = self.inp_block(x).permute(0,2,1)
|
||||
|
||||
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
||||
rotary_pos_emb = None if self.rotary_embeddings is None else self.rotary_embeddings(x.shape[1], x.device)
|
||||
x = self.layers(x, code_emb, blk_emb, rotary_pos_emb)
|
||||
|
||||
x = x.float().permute(0,2,1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user