new attempt
This commit is contained in:
parent
968660c248
commit
be937d202e
|
@ -5,6 +5,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
|
||||
from models.arch_util import ResBlock
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock
|
||||
from trainer.networks import register_model
|
||||
|
@ -18,7 +19,7 @@ def is_sequence(t):
|
|||
return t.dtype == torch.long
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
class TimestepResBlock(TimestepBlock):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
|
@ -98,7 +99,7 @@ class ResBlock(TimestepBlock):
|
|||
class DiffusionLayer(TimestepBlock):
|
||||
def __init__(self, model_channels, dropout, num_heads):
|
||||
super().__init__()
|
||||
self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
|
||||
self.resblk = TimestepResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
|
||||
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
|
||||
|
||||
def forward(self, x, time_emb):
|
||||
|
@ -122,6 +123,7 @@ class FlatDiffusion(nn.Module):
|
|||
# Parameters for regularization.
|
||||
layer_drop=.1,
|
||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||
train_mel_head=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -154,9 +156,11 @@ class FlatDiffusion(nn.Module):
|
|||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
)
|
||||
self.code_converter = nn.Sequential(
|
||||
ResBlock(dims=1, channels=model_channels, dropout=dropout),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
ResBlock(dims=1, channels=model_channels, dropout=dropout),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
ResBlock(dims=1, channels=model_channels, dropout=dropout),
|
||||
)
|
||||
self.code_norm = normalization(model_channels)
|
||||
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
|
||||
|
@ -176,7 +180,7 @@ class FlatDiffusion(nn.Module):
|
|||
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
|
||||
[ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
|
||||
[TimestepResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(model_channels),
|
||||
|
@ -184,6 +188,13 @@ class FlatDiffusion(nn.Module):
|
|||
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
if train_mel_head:
|
||||
for m in [self.conditioning_timestep_integrator, self.integrating_conv, self.layers,
|
||||
self.out]:
|
||||
for p in m.parameters():
|
||||
p.requires_grad = False
|
||||
p.DO_NOT_TRAIN = True
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
groups = {
|
||||
'minicoder': list(self.contextual_embedder.parameters()),
|
||||
|
@ -213,7 +224,6 @@ class FlatDiffusion(nn.Module):
|
|||
else:
|
||||
code_emb = [embedding(aligned_conditioning[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
||||
code_emb = torch.cat(code_emb, dim=-1).permute(0,2,1)
|
||||
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
||||
|
||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
|
@ -224,6 +234,7 @@ class FlatDiffusion(nn.Module):
|
|||
code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
||||
expanded_code_emb = self.code_converter(expanded_code_emb)
|
||||
expanded_code_emb = self.code_norm(expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
||||
|
||||
if not return_code_pred:
|
||||
return expanded_code_emb
|
||||
|
@ -313,7 +324,7 @@ if __name__ == '__main__':
|
|||
aligned_sequence = torch.randint(0,8,(2,100,8))
|
||||
cond = torch.randn(2, 256, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5)
|
||||
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, train_mel_head=True)
|
||||
# Test with latent aligned conditioning
|
||||
#o = model(clip, ts, aligned_latent, cond)
|
||||
# Test with sequence aligned conditioning
|
||||
|
|
|
@ -25,7 +25,7 @@ def recover_codegroups(codes, groups):
|
|||
if __name__ == '__main__':
|
||||
model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
|
||||
mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, disable_custom_linear_init=True)
|
||||
model.load_state_dict(torch.load("X:\\dlas\\experiments\\train_music_mel2vec\\models\\29000_generator_ema.pth"))
|
||||
model.load_state_dict(torch.load("../experiments/m2v_music.pth"))
|
||||
model.eval()
|
||||
|
||||
wav = load_audio("Y:/separated/bt-music-1/100 Hits - Running Songs 2014 CD 2/100 Hits - Running Songs 2014 Cd2 - 02 - 7Th Heaven - Ain't Nothin' Goin' On But The Rent/00001/no_vocals.wav", 22050)
|
||||
|
|
Loading…
Reference in New Issue
Block a user