unet with ar prior
This commit is contained in:
parent
5028703b3d
commit
5a54d7db11
|
@ -6,6 +6,7 @@ from transformers import GPT2Config, GPT2Model
|
|||
from models.arch_util import AttentionBlock
|
||||
from models.audio.music.music_quantizer import MusicQuantizer
|
||||
from models.audio.music.music_quantizer2 import MusicQuantizer2
|
||||
from models.lucidrains.x_transformers import Encoder
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
@ -31,34 +32,55 @@ class ConditioningEncoder(nn.Module):
|
|||
|
||||
|
||||
class GptMusicLower(nn.Module):
|
||||
def __init__(self, dim, layers, num_target_vectors=512, num_target_groups=2, cv_dim=1024, num_upper_vectors=64, num_upper_groups=4):
|
||||
def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4):
|
||||
super().__init__()
|
||||
self.internal_step = 0
|
||||
self.num_groups = num_target_groups
|
||||
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
|
||||
n_inner=dim*2)
|
||||
self.target_quantizer = MusicQuantizer(inp_channels=256, inner_dim=[1024,1024,512], codevector_dim=cv_dim, codebook_size=num_target_vectors, codebook_groups=num_target_groups)
|
||||
self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024,896,768,640,512,384], codevector_dim=cv_dim, codebook_size=num_upper_vectors, codebook_groups=num_upper_groups)
|
||||
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False)
|
||||
self.target_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024], codevector_dim=1024, codebook_size=256,
|
||||
codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5)
|
||||
self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[dim,
|
||||
max(512,dim-128),
|
||||
max(512,dim-256),
|
||||
max(512,dim-384),
|
||||
max(512,dim-512),
|
||||
max(512,dim-512)], codevector_dim=dim,
|
||||
codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True)
|
||||
# Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
|
||||
del self.target_quantizer.decoder
|
||||
del self.target_quantizer.up
|
||||
del self.upper_quantizer.up
|
||||
# Freeze the target quantizer.
|
||||
for p in self.target_quantizer.parameters():
|
||||
p.DO_NOT_TRAIN = True
|
||||
p.requires_grad = False
|
||||
|
||||
self.upper_mixer = Encoder(
|
||||
dim=dim,
|
||||
depth=4,
|
||||
heads=dim//64,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_emb_dim=True,
|
||||
)
|
||||
self.conditioning_encoder = ConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64)
|
||||
|
||||
self.gpt = GPT2Model(self.config)
|
||||
del self.gpt.wte # Unused, we'll do our own embeddings.
|
||||
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_target_groups) for _ in range(num_target_groups)])
|
||||
self.upper_proj = nn.Conv1d(cv_dim, dim, kernel_size=1)
|
||||
self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_target_groups)])
|
||||
|
||||
|
||||
def forward(self, mel, conditioning):
|
||||
def forward(self, mel, conditioning, return_latent=False):
|
||||
with torch.no_grad():
|
||||
self.target_quantizer.eval()
|
||||
codes = self.target_quantizer.get_codes(mel)
|
||||
upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True)
|
||||
upper_vector = self.upper_proj(upper_vector)
|
||||
upper_vector = self.upper_mixer(upper_vector.permute(0,2,1)).permute(0,2,1) # Allow the upper vector to fully attend to itself (the whole thing is a prior.)
|
||||
upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear')
|
||||
upper_vector = upper_vector.permute(0,2,1)
|
||||
|
||||
|
@ -68,21 +90,49 @@ class GptMusicLower(nn.Module):
|
|||
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
||||
h = torch.cat(h, dim=-1) + upper_vector
|
||||
|
||||
# Stick the conditioning embedding on the front of the input sequence.
|
||||
# The transformer will learn how to integrate it.
|
||||
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
|
||||
cond_emb = self.conditioning_encoder(conditioning).unsqueeze(1)
|
||||
h = torch.cat([cond_emb, h], dim=1)
|
||||
with torch.autocast(mel.device.type):
|
||||
# Stick the conditioning embedding on the front of the input sequence.
|
||||
# The transformer will learn how to integrate it.
|
||||
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
|
||||
cond_emb = self.conditioning_encoder(conditioning).unsqueeze(1)
|
||||
h = torch.cat([cond_emb, h], dim=1)
|
||||
|
||||
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
|
||||
h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
|
||||
|
||||
losses = 0
|
||||
for i, head in enumerate(self.heads):
|
||||
logits = head(h).permute(0,2,1)
|
||||
loss = F.cross_entropy(logits, targets[:,:,i])
|
||||
losses = losses + loss
|
||||
if return_latent:
|
||||
return h.float()
|
||||
|
||||
return losses / self.num_groups
|
||||
losses = 0
|
||||
for i, head in enumerate(self.heads):
|
||||
logits = head(h).permute(0,2,1)
|
||||
loss = F.cross_entropy(logits, targets[:,:,i])
|
||||
losses = losses + loss
|
||||
|
||||
return losses / self.num_groups, upper_diversity
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
groups = {
|
||||
'gpt': list(self.gpt.parameters()),
|
||||
'conditioning': list(self.conditioning_encoder.parameters()),
|
||||
'upper_mixer': list(self.upper_mixer.parameters()),
|
||||
'upper_quant_down': list(self.upper_quantizer.down.parameters()),
|
||||
'upper_quant_encoder': list(self.upper_quantizer.encoder.parameters()),
|
||||
'upper_quant_codebook': [self.upper_quantizer.quantizer.codevectors],
|
||||
}
|
||||
return groups
|
||||
|
||||
def get_debug_values(self, step, __):
|
||||
if self.upper_quantizer.total_codes > 0:
|
||||
return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes]}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def update_for_step(self, step, *args):
|
||||
self.internal_step = step
|
||||
self.upper_quantizer.quantizer.temperature = max(
|
||||
self.upper_quantizer.max_gumbel_temperature * self.upper_quantizer.gumbel_temperature_decay**self.internal_step,
|
||||
self.upper_quantizer.min_gumbel_temperature,
|
||||
)
|
||||
|
||||
|
||||
@register_model
|
||||
|
@ -91,6 +141,15 @@ def register_music_gpt_lower(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from models.audio.music.transformer_diffusion8 import TransformerDiffusionWithQuantizer
|
||||
base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024,
|
||||
prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024,
|
||||
dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
|
||||
base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/28000_generator.pth', map_location=torch.device('cpu')))
|
||||
|
||||
model = GptMusicLower(512, 12)
|
||||
model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
|
||||
torch.save(model.state_dict(), "sample.pth")
|
||||
mel = torch.randn(2,256,400)
|
||||
model(mel, mel)
|
||||
model(mel, mel)
|
||||
model.get_grad_norm_parameter_groups()
|
|
@ -11,13 +11,25 @@ from utils.util import checkpoint, ceil_multiple, print_network
|
|||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, chan_in, chan_out):
|
||||
def __init__(self, chan_in, chan_out, norm=False, act=False, stride_down=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1)
|
||||
self.interpolate = not stride_down
|
||||
if stride_down:
|
||||
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1, stride=2)
|
||||
else:
|
||||
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1)
|
||||
if norm:
|
||||
self.norm = nn.GroupNorm(8, chan_out)
|
||||
self.act = act
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=.5, mode='linear')
|
||||
if self.interpolate:
|
||||
x = F.interpolate(x, scale_factor=.5, mode='linear')
|
||||
x = self.conv(x)
|
||||
if hasattr(self, 'norm'):
|
||||
x = self.norm(x)
|
||||
if self.act:
|
||||
x = F.silu(x, inplace=True)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -153,7 +165,9 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
|||
class MusicQuantizer2(nn.Module):
|
||||
def __init__(self, inp_channels=256, inner_dim=1024, codevector_dim=1024, down_steps=2,
|
||||
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
|
||||
codebook_size=16, codebook_groups=4):
|
||||
codebook_size=16, codebook_groups=4,
|
||||
# Downsample args:
|
||||
expressive_downsamples=False):
|
||||
super().__init__()
|
||||
if not isinstance(inner_dim, list):
|
||||
inner_dim = [inner_dim // 2 ** x for x in range(down_steps+1)]
|
||||
|
@ -172,7 +186,8 @@ class MusicQuantizer2(nn.Module):
|
|||
self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1)
|
||||
elif down_steps == 2:
|
||||
self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1),
|
||||
*[Downsample(inner_dim[-i], inner_dim[-i-1]) for i in range(1,len(inner_dim))])
|
||||
*[Downsample(inner_dim[-i], inner_dim[-i-1], norm=expressive_downsamples, act=expressive_downsamples,
|
||||
stride_down=expressive_downsamples) for i in range(1,len(inner_dim))])
|
||||
self.up = nn.Sequential(*[Upsample(inner_dim[i], inner_dim[i+1]) for i in range(len(inner_dim)-1)] +
|
||||
[nn.Conv1d(inner_dim[-1], inp_channels, kernel_size=3, padding=1)])
|
||||
|
||||
|
@ -190,14 +205,11 @@ class MusicQuantizer2(nn.Module):
|
|||
self.code_ind = 0
|
||||
self.total_codes = 0
|
||||
|
||||
def get_codes(self, mel, project=False):
|
||||
proj = self.m2v.input_blocks(mel).permute(0,2,1)
|
||||
_, proj = self.m2v.projector(proj)
|
||||
if project:
|
||||
proj, _ = self.quantizer(proj)
|
||||
return proj
|
||||
else:
|
||||
return self.quantizer.get_codes(proj)
|
||||
def get_codes(self, mel):
|
||||
h = self.down(mel)
|
||||
h = self.encoder(h)
|
||||
h = self.enc_norm(h.permute(0,2,1))
|
||||
return self.quantizer.get_codes(h)
|
||||
|
||||
def forward(self, mel, return_decoder_latent=False):
|
||||
orig_mel = mel
|
||||
|
|
|
@ -10,6 +10,7 @@ import torch.nn.functional as F
|
|||
import torchvision # For debugging, not actually used.
|
||||
from x_transformers.x_transformers import RelativePositionBias
|
||||
|
||||
from models.audio.music.gpt_music import GptMusicLower
|
||||
from models.audio.music.music_quantizer import MusicQuantizer
|
||||
from models.diffusion.fp16_util import convert_module_to_f16, convert_module_to_f32
|
||||
from models.diffusion.nn import (
|
||||
|
@ -451,6 +452,7 @@ class UNetMusicModel(nn.Module):
|
|||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
ar_prior=False,
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
|
@ -483,6 +485,7 @@ class UNetMusicModel(nn.Module):
|
|||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.unconditioned_percentage = unconditioned_percentage
|
||||
self.ar_prior = ar_prior
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
|
@ -491,8 +494,9 @@ class UNetMusicModel(nn.Module):
|
|||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.input_converter = nn.Linear(input_vec_dim, model_channels)
|
||||
self.code_converter = Encoder(
|
||||
if self.ar_prior:
|
||||
self.ar_input = nn.Linear(input_vec_dim, model_channels)
|
||||
self.ar_prior_intg = Encoder(
|
||||
dim=model_channels,
|
||||
depth=4,
|
||||
heads=num_heads,
|
||||
|
@ -504,6 +508,20 @@ class UNetMusicModel(nn.Module):
|
|||
zero_init_branch_output=True,
|
||||
ff_mult=1,
|
||||
)
|
||||
else:
|
||||
self.input_converter = nn.Linear(input_vec_dim, model_channels)
|
||||
self.code_converter = Encoder(
|
||||
dim=model_channels,
|
||||
depth=4,
|
||||
heads=num_heads,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_output=True,
|
||||
ff_mult=1,
|
||||
)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
||||
self.x_processor = conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
|
||||
|
@ -659,15 +677,18 @@ class UNetMusicModel(nn.Module):
|
|||
|
||||
if conditioning_free:
|
||||
expanded_code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1).permute(0,2,1)
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.input_converter.parameters()))
|
||||
if self.ar_prior:
|
||||
unused_params.extend(list(self.ar_input.parameters()) + list(self.ar_prior_intg.parameters()))
|
||||
else:
|
||||
unused_params.extend(list(self.input_converter.parameters()) + list(self.code_converter.parameters()))
|
||||
else:
|
||||
code_emb = self.input_converter(y)
|
||||
code_emb = self.ar_input(y) if self.ar_prior else self.input_converter(y)
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(y.shape[0], 1, 1),
|
||||
code_emb)
|
||||
code_emb = self.code_converter(code_emb)
|
||||
code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=x.shape[-1], mode='nearest')
|
||||
|
||||
h = x.type(self.dtype)
|
||||
|
@ -740,23 +761,60 @@ class UNetMusicModelWithQuantizer(nn.Module):
|
|||
return {}
|
||||
|
||||
|
||||
class UNetMusicModelARPrior(nn.Module):
|
||||
def __init__(self, freeze_unet=False, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.internal_step = 0
|
||||
self.ar = GptMusicLower(dim=512, layers=12)
|
||||
for p in self.ar.parameters():
|
||||
p.DO_NOT_TRAIN = True
|
||||
p.requires_grad = False
|
||||
|
||||
self.diff = UNetMusicModel(ar_prior=True, **kwargs)
|
||||
if freeze_unet:
|
||||
for p in self.diff.parameters():
|
||||
p.DO_NOT_TRAIN = True
|
||||
p.requires_grad = False
|
||||
for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()):
|
||||
del p.DO_NOT_TRAIN
|
||||
p.requires_grad = True
|
||||
|
||||
def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
|
||||
with torch.no_grad():
|
||||
prior = self.ar(truth_mel, conditioning_input, return_latent=True)
|
||||
|
||||
diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free)
|
||||
return diff
|
||||
|
||||
|
||||
@register_model
|
||||
def register_unet_diffusion_music_codes(opt_net, opt):
|
||||
return UNetMusicModelWithQuantizer(**opt_net['args'])
|
||||
|
||||
@register_model
|
||||
def register_unet_diffusion_music_ar_prior(opt_net, opt):
|
||||
return UNetMusicModelARPrior(**opt_net['args'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
clip = torch.randn(2, 256, 782)
|
||||
cond = torch.randn(2, 256, 782)
|
||||
clip = torch.randn(2, 256, 300)
|
||||
cond = torch.randn(2, 256, 300)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=1024, num_res_blocks=3, input_vec_dim=1024,
|
||||
attention_resolutions=(2,4), channel_mult=(1,1.5,2), dims=1,
|
||||
use_scale_shift_norm=True, dropout=.1, num_heads=16, unconditioned_percentage=.4)
|
||||
model = UNetMusicModelARPrior(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=512,
|
||||
attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
|
||||
use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4, freeze_unet=True)
|
||||
print_network(model)
|
||||
|
||||
quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth')
|
||||
model.m2v.load_state_dict(quant_weights, strict=False)
|
||||
ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
|
||||
model.ar.load_state_dict(ar_weights, strict=True)
|
||||
diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_unet_music\\models\\55500_generator_ema.pth')
|
||||
pruned_diff_weights = {}
|
||||
for k,v in diff_weights.items():
|
||||
if k.startswith('diff.'):
|
||||
pruned_diff_weights[k.replace('diff.', '')] = v
|
||||
model.diff.load_state_dict(pruned_diff_weights, strict=False)
|
||||
torch.save(model.state_dict(), 'sample.pth')
|
||||
|
||||
model(clip, ts, cond)
|
||||
model(clip, ts, cond, cond)
|
||||
|
||||
|
|
|
@ -339,7 +339,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_quant.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
|
@ -201,13 +201,13 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\train_music_diffusion_tfd5_quant.yml', 'generator',
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_quant7.yml', 'generator',
|
||||
also_load_savepoint=False,
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd5_quant\\models\\40500_generator_ema.pth'
|
||||
load_path='X:\\dlas\\experiments\\train_music_diffusion_unet_music\\models\\46500_generator_ema.pth'
|
||||
).cuda()
|
||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 100,
|
||||
'conditioning_free': True, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'from_codes_quant'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 560, 'device': 'cuda', 'opt': {}}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 561, 'device': 'cuda', 'opt': {}}
|
||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||
print(eval.perform_eval())
|
||||
|
|
Loading…
Reference in New Issue
Block a user