tfd7

parent b2a83efe50
commit 2f4d990ad1
@@ -1,10 +1,12 @@
+import functools
+
 import torch
 from torch import nn
 import torch.nn.functional as F

 from models.arch_util import zero_module
 from trainer.networks import register_model
-from utils.util import checkpoint, ceil_multiple
+from utils.util import checkpoint, ceil_multiple, print_network


 class Downsample(nn.Module):
@@ -152,33 +154,37 @@ class MusicQuantizer(nn.Module):
                  max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
                  codebook_size=16, codebook_groups=4):
         super().__init__()
+        if not isinstance(inner_dim, list):
+            inner_dim = [inner_dim // 2 ** x for x in range(down_steps+1)]
         self.max_gumbel_temperature = max_gumbel_temperature
         self.min_gumbel_temperature = min_gumbel_temperature
         self.gumbel_temperature_decay = gumbel_temperature_decay
-        self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim, codevector_dim=codevector_dim,
+        self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim[0], codevector_dim=codevector_dim,
                                                        num_codevector_groups=codebook_groups,
                                                        num_codevectors_per_group=codebook_size)
+        self.codebook_size = codebook_size
+        self.codebook_groups = codebook_groups
         self.num_losses_record = []

         if down_steps == 0:
-            self.down = nn.Conv1d(inp_channels, inner_dim, kernel_size=3, padding=1)
-            self.up = nn.Conv1d(inner_dim, inp_channels, kernel_size=3, padding=1)
+            self.down = nn.Conv1d(inp_channels, inner_dim[0], kernel_size=3, padding=1)
+            self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1)
         elif down_steps == 2:
-            self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim//4, kernel_size=3, padding=1),
-                                      Downsample(inner_dim//4, inner_dim//2),
-                                      Downsample(inner_dim//2, inner_dim))
-            self.up = nn.Sequential(Upsample(inner_dim, inner_dim//2),
-                                    Upsample(inner_dim//2, inner_dim//4),
-                                    nn.Conv1d(inner_dim//4, inp_channels, kernel_size=3, padding=1))
+            self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1),
+                                      Downsample(inner_dim[-1], inner_dim[-2]),
+                                      Downsample(inner_dim[-2], inner_dim[-3]))
+            self.up = nn.Sequential(Upsample(inner_dim[-3], inner_dim[-2]),
+                                    Upsample(inner_dim[-2], inner_dim[-1]),
+                                    nn.Conv1d(inner_dim[-1], inp_channels, kernel_size=3, padding=1))

-        self.encoder = nn.Sequential(ResBlock(inner_dim),
-                                     ResBlock(inner_dim),
-                                     ResBlock(inner_dim))
-        self.enc_norm = nn.LayerNorm(inner_dim, eps=1e-5)
-        self.decoder = nn.Sequential(nn.Conv1d(codevector_dim, inner_dim, kernel_size=3, padding=1),
-                                     ResBlock(inner_dim),
-                                     ResBlock(inner_dim),
-                                     ResBlock(inner_dim))
+        self.encoder = nn.Sequential(ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]))
+        self.enc_norm = nn.LayerNorm(inner_dim[0], eps=1e-5)
+        self.decoder = nn.Sequential(nn.Conv1d(codevector_dim, inner_dim[0], kernel_size=3, padding=1),
+                                     ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]))

         self.codes = torch.zeros((3000000,), dtype=torch.long)
         self.internal_step = 0
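Note: after this hunk, inner_dim may be either a plain int (the legacy form) or an explicit widest-first channel list; inner_dim[0] is the bottleneck width shared by the quantizer, encoder, and decoder, while inner_dim[-1] is the width next to the input/output convs. A minimal sketch of the expansion, assuming the default down_steps=2:

    # Sketch only: mirrors the isinstance() branch added above.
    inner_dim, down_steps = 2048, 2
    if not isinstance(inner_dim, list):
        inner_dim = [inner_dim // 2 ** x for x in range(down_steps + 1)]
    assert inner_dim == [2048, 1024, 512]  # widest-first pyramid
    # inner_dim[0]  -> quantizer/encoder/decoder width (deepest level)
    # inner_dim[-1] -> channels right after the first conv from inp_channels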
@@ -210,7 +216,7 @@ class MusicQuantizer(nn.Module):
         if return_decoder_latent:
             return h, diversity

-        reconstructed = self.up(h)
+        reconstructed = self.up(h.float())
         reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]

         mse = F.mse_loss(reconstructed, orig_mel)
@@ -219,7 +225,10 @@ class MusicQuantizer(nn.Module):
     def log_codes(self, codes):
         if self.internal_step % 5 == 0:
             codes = torch.argmax(codes, dim=-1)
-            codes = codes[:,:,0] + codes[:,:,1] * 16 + codes[:,:,2] * 16 ** 2 + codes[:,:,3] * 16 ** 3
+            ccodes = codes[:,:,0]
+            for j in range(1,codes.shape[-1]):
+                ccodes += codes[:,:,j] * self.codebook_size ** j
+            codes = ccodes
             codes = codes.flatten()
             l = codes.shape[0]
             i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
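Note: the removed line hard-coded four codebook groups with base 16 (the 16 ** j factors); the loop generalizes this by packing the per-group argmax indices as digits of a base-codebook_size number, so any codebook_size/codebook_groups combination yields a unique combined id. A standalone sketch with hypothetical shapes:

    import torch

    codebook_size, codebook_groups = 512, 2  # the values used in __main__ below
    # per-group argmax indices, shape (batch, time, groups)
    codes = torch.randint(codebook_size, (2, 100, codebook_groups))
    combined = codes[:, :, 0]
    for j in range(1, codes.shape[-1]):
        combined = combined + codes[:, :, j] * codebook_size ** j
    # each timestep gets an id in [0, codebook_size ** codebook_groups), here [0, 262144)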
@@ -242,6 +251,7 @@ def register_music_quantizer(opt_net, opt):


 if __name__ == '__main__':
-    model = MusicQuantizer()
+    model = MusicQuantizer(inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=512, codebook_groups=2)
+    print_network(model)
     mel = torch.randn((2,256,782))
     model(mel)
@@ -60,6 +60,7 @@ class TransformerDiffusion(nn.Module):
     def __init__(
             self,
             prenet_channels=256,
+            prenet_layers=3,
             model_channels=512,
             block_channels=256,
             num_layers=8,
@@ -108,7 +109,7 @@ class TransformerDiffusion(nn.Module):
         self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
         self.code_converter = Encoder(
                     dim=prenet_channels,
-                    depth=3,
+                    depth=prenet_layers,
                     heads=prenet_heads,
                     ff_dropout=dropout,
                     attn_dropout=dropout,
@@ -205,7 +206,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         self.internal_step = 0
         self.freeze_quantizer_until = freeze_quantizer_until
         self.diff = TransformerDiffusion(**kwargs)
-        self.m2v = MusicQuantizer(inp_channels=256, inner_dim=2048, codevector_dim=1024)
+        self.m2v = MusicQuantizer(inp_channels=256, inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=512, codebook_groups=2)
         self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
         del self.m2v.up

@@ -270,14 +271,14 @@ if __name__ == '__main__':
     clip = torch.randn(2, 256, 400)
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, input_vec_dim=2048, num_layers=16)
+    model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, input_vec_dim=1024, num_layers=16, prenet_layers=6)

-    #quant_weights = torch.load('X:\\dlas\\experiments\\train_music_quant\\models\\1000_generator.pth')
+    quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth')
     #diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth')
-    #model.m2v.load_state_dict(quant_weights, strict=False)
+    model.m2v.load_state_dict(quant_weights, strict=False)
     #model.diff.load_state_dict(diff_weights)

-    #torch.save(model.state_dict(), 'sample.pth')
+    torch.save(model.state_dict(), 'sample.pth')
     print_network(model)
     o = model(clip, ts, clip, cond)

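Note: strict=False is load-bearing here because the wrapper ran del self.m2v.up, so the quantizer checkpoint still contains up.* tensors that the module no longer declares. A self-contained sketch of that pattern (toy module, illustrative only, not the repo's API):

    import torch
    from torch import nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.down = nn.Linear(4, 4)
            self.up = nn.Linear(4, 4)

    m = Toy()
    sd = m.state_dict()  # stands in for the on-disk checkpoint
    del m.up             # the wrapper drops the decoder half the same way
    result = m.load_state_dict(sd, strict=False)
    # strict=True would raise on the now-unexpected 'up.*' keys;
    # strict=False just reports them:
    print(result.unexpected_keys)  # ['up.weight', 'up.bias']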