new attempt

parent 968660c248
commit be937d202e
@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import autocast
 
+from models.arch_util import ResBlock
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock
 from trainer.networks import register_model
@@ -18,7 +19,7 @@ def is_sequence(t):
     return t.dtype == torch.long
 
 
-class ResBlock(TimestepBlock):
+class TimestepResBlock(TimestepBlock):
    def __init__(
            self,
            channels,
@@ -98,7 +99,7 @@ class ResBlock(TimestepBlock):
 class DiffusionLayer(TimestepBlock):
     def __init__(self, model_channels, dropout, num_heads):
         super().__init__()
-        self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
+        self.resblk = TimestepResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
         self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
 
     def forward(self, x, time_emb):
@@ -122,6 +123,7 @@ class FlatDiffusion(nn.Module):
             # Parameters for regularization.
             layer_drop=.1,
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
+            train_mel_head=False,
     ):
         super().__init__()
 
@@ -154,9 +156,11 @@ class FlatDiffusion(nn.Module):
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
         )
         self.code_converter = nn.Sequential(
+            ResBlock(dims=1, channels=model_channels, dropout=dropout),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            ResBlock(dims=1, channels=model_channels, dropout=dropout),
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
-            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            ResBlock(dims=1, channels=model_channels, dropout=dropout),
         )
         self.code_norm = normalization(model_channels)
         self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
@@ -176,7 +180,7 @@ class FlatDiffusion(nn.Module):
         self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
 
         self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
-                                    [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
+                                    [TimestepResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
 
         self.out = nn.Sequential(
             normalization(model_channels),
@@ -184,6 +188,13 @@ class FlatDiffusion(nn.Module):
             zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
         )
 
+        if train_mel_head:
+            for m in [self.conditioning_timestep_integrator, self.integrating_conv, self.layers,
+                      self.out]:
+                for p in m.parameters():
+                    p.requires_grad = False
+                    p.DO_NOT_TRAIN = True
+
     def get_grad_norm_parameter_groups(self):
         groups = {
             'minicoder': list(self.contextual_embedder.parameters()),
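
A side note on the freeze introduced in the hunk above: when train_mel_head is set, every parameter of the listed submodules gets requires_grad = False and a DO_NOT_TRAIN attribute. Below is a minimal sketch of how an optimizer could be restricted to the remaining parameters; the trainable_parameters helper and the toy two-layer model are illustrative assumptions, not part of this commit or of the trainer in this repo.

import torch
import torch.nn as nn

def trainable_parameters(model):
    # Hypothetical helper: skip parameters frozen the way the train_mel_head
    # branch freezes them (requires_grad cleared, DO_NOT_TRAIN attribute set).
    for p in model.parameters():
        if getattr(p, 'DO_NOT_TRAIN', False) or not p.requires_grad:
            continue
        yield p

# Toy stand-in model: freeze the first layer the same way the commit does.
m = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in m[0].parameters():
    p.requires_grad = False
    p.DO_NOT_TRAIN = True

opt = torch.optim.AdamW(trainable_parameters(m), lr=1e-4)
assert len(opt.param_groups[0]['params']) == 2  # only the second layer's weight and bias
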
@@ -213,7 +224,6 @@ class FlatDiffusion(nn.Module):
         else:
             code_emb = [embedding(aligned_conditioning[:, :, i]) for i, embedding in enumerate(self.embeddings)]
             code_emb = torch.cat(code_emb, dim=-1).permute(0,2,1)
-            code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
 
         unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
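
The comment about classifier-free guidance in the hunk above refers to randomly replacing the conditioning for whole batch elements with a learned "unconditioned" embedding during training. A minimal, self-contained sketch of that mechanism follows; the shapes and the standalone unconditioned_embedding parameter are placeholders chosen for illustration, mirroring the names visible in the diff rather than the full class.

import torch

batch, channels, tokens = 4, 512, 100
unconditioned_percentage = 0.1

code_emb = torch.randn(batch, channels, tokens)
unconditioned_embedding = torch.nn.Parameter(torch.randn(1, channels, 1))

# Pick whole batch elements whose conditioning is dropped entirely.
unconditioned_batches = torch.rand(batch, 1, 1) < unconditioned_percentage
code_emb = torch.where(unconditioned_batches,
                       unconditioned_embedding.repeat(batch, 1, tokens),
                       code_emb)
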
@@ -224,6 +234,7 @@ class FlatDiffusion(nn.Module):
                                    code_emb)
         expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
         expanded_code_emb = self.code_converter(expanded_code_emb)
+        expanded_code_emb = self.code_norm(expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
 
         if not return_code_pred:
             return expanded_code_emb
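
The line added in the hunk above applies a FiLM-style modulation after the code converter: normalize, then scale by (1 + cond_scale) and offset by cond_shift, per channel. Here is a small self-contained sketch of that operation, with a plain GroupNorm standing in for normalization(model_channels) and random tensors standing in for the real conditioning outputs:

import torch
import torch.nn as nn

channels, seq_len = 512, 100
code_norm = nn.GroupNorm(32, channels)   # stand-in for normalization(model_channels)

expanded_code_emb = torch.randn(2, channels, seq_len)
cond_scale = torch.randn(2, channels)    # in the model these come from the conditioning embedding
cond_shift = torch.randn(2, channels)

# Scale-shift norm: normalize, then modulate each channel with the conditioning.
out = code_norm(expanded_code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
print(out.shape)  # torch.Size([2, 512, 100])
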
@@ -313,7 +324,7 @@ if __name__ == '__main__':
     aligned_sequence = torch.randint(0,8,(2,100,8))
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5)
+    model = FlatDiffusion(512, layer_drop=.3, unconditioned_percentage=.5, train_mel_head=True)
     # Test with latent aligned conditioning
     #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning

@@ -25,7 +25,7 @@ def recover_codegroups(codes, groups):
 if __name__ == '__main__':
     model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
                                        mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, disable_custom_linear_init=True)
-    model.load_state_dict(torch.load("X:\\dlas\\experiments\\train_music_mel2vec\\models\\29000_generator_ema.pth"))
+    model.load_state_dict(torch.load("../experiments/m2v_music.pth"))
     model.eval()
 
     wav = load_audio("Y:/separated/bt-music-1/100 Hits - Running Songs 2014 CD 2/100 Hits - Running Songs 2014 Cd2 - 02 - 7Th Heaven - Ain't Nothin' Goin' On But The Rent/00001/no_vocals.wav", 22050)