new attempt

This commit is contained in:
James Betker 2022-05-20 17:04:22 -06:00
parent 968660c248
commit be937d202e
2 changed files with 18 additions and 7 deletions

View File

@ -5,6 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from models.arch_util import ResBlock
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock
from trainer.networks import register_model
@ -18,7 +19,7 @@ def is_sequence(t):
return t.dtype == torch.long
class ResBlock(TimestepBlock):
class TimestepResBlock(TimestepBlock):
def __init__(
self,
channels,
@ -98,7 +99,7 @@ class ResBlock(TimestepBlock):
class DiffusionLayer(TimestepBlock):
def __init__(self, model_channels, dropout, num_heads):
super().__init__()
self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
self.resblk = TimestepResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
def forward(self, x, time_emb):
@ -122,6 +123,7 @@ class FlatDiffusion(nn.Module):
# Parameters for regularization.
layer_drop=.1,
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
train_mel_head=False,
):
super().__init__()
@ -154,9 +156,11 @@ class FlatDiffusion(nn.Module):
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
)
self.code_converter = nn.Sequential(
ResBlock(dims=1, channels=model_channels, dropout=dropout),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
ResBlock(dims=1, channels=model_channels, dropout=dropout),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
ResBlock(dims=1, channels=model_channels, dropout=dropout),
)
self.code_norm = normalization(model_channels)
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
@ -176,7 +180,7 @@ class FlatDiffusion(nn.Module):
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
[ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
[TimestepResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
self.out = nn.Sequential(
normalization(model_channels),
@ -184,6 +188,13 @@ class FlatDiffusion(nn.Module):
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
)
if train_mel_head:
for m in [self.conditioning_timestep_integrator, self.integrating_conv, self.layers,
self.out]:
for p in m.parameters():
p.requires_grad = False
p.DO_NOT_TRAIN = True
def get_grad_norm_parameter_groups(self):
groups = {
'minicoder': list(self.contextual_embedder.parameters()),
@ -213,7 +224,6 @@ class FlatDiffusion(nn.Module):
else:
code_emb = [embedding(aligned_conditioning[:, :, i]) for i, embedding in enumerate(self.embeddings)]
code_emb = torch.cat(code_emb, dim=-1).permute(0,2,1)
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
@ -224,6 +234,7 @@ class FlatDiffusion(nn.Module):
code_emb)
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
expanded_code_emb = self.code_converter(expanded_code_emb)
expanded_code_emb = self.code_norm(expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
if not return_code_pred:
return expanded_code_emb
@ -313,7 +324,7 @@ if __name__ == '__main__':
aligned_sequence = torch.randint(0,8,(2,100,8))
cond = torch.randn(2, 256, 400)
ts = torch.LongTensor([600, 600])
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5)
model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, train_mel_head=True)
# Test with latent aligned conditioning
#o = model(clip, ts, aligned_latent, cond)
# Test with sequence aligned conditioning

View File

@ -25,7 +25,7 @@ def recover_codegroups(codes, groups):
if __name__ == '__main__':
model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, disable_custom_linear_init=True)
model.load_state_dict(torch.load("X:\\dlas\\experiments\\train_music_mel2vec\\models\\29000_generator_ema.pth"))
model.load_state_dict(torch.load("../experiments/m2v_music.pth"))
model.eval()
wav = load_audio("Y:/separated/bt-music-1/100 Hits - Running Songs 2014 CD 2/100 Hits - Running Songs 2014 Cd2 - 02 - 7Th Heaven - Ain't Nothin' Goin' On But The Rent/00001/no_vocals.wav", 22050)