forked from mrq/DL-Art-School
mods to support cheater ar prior in tfd12
parent 286918c581
commit 58f26b1900
@@ -98,7 +98,6 @@ class TransformerDiffusion(nn.Module):
             num_heads=4,
             dropout=0,
             use_fp16=False,
-            ar_prior=False,
             new_code_expansion=False,
             permute_codes=False,
             # Parameters for regularization.
@@ -127,23 +126,7 @@ class TransformerDiffusion(nn.Module):
             linear(time_embed_dim, time_embed_dim),
         )

-        self.ar_prior = ar_prior
         prenet_heads = prenet_channels//64
-        if ar_prior:
-            self.ar_input = nn.Linear(input_vec_dim, prenet_channels)
-            self.ar_prior_intg = Encoder(
-                    dim=prenet_channels,
-                    depth=prenet_layers,
-                    heads=prenet_heads,
-                    ff_dropout=dropout,
-                    attn_dropout=dropout,
-                    use_rmsnorm=True,
-                    ff_glu=True,
-                    rotary_pos_emb=True,
-                    zero_init_branch_output=True,
-                    ff_mult=1,
-                )
-        else:
-            self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
-            self.code_converter = Encoder(
+        self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
+        self.code_converter = Encoder(
                 dim=prenet_channels,
@@ -173,12 +156,6 @@ class TransformerDiffusion(nn.Module):
             for p in self.parameters():
                 p.DO_NOT_TRAIN = True
                 p.requires_grad = False
-            if hasattr(self, 'ar_input'):
-                for m in [self.ar_input and self.ar_prior_intg]:
-                    for p in m.parameters():
-                        del p.DO_NOT_TRAIN
-                        p.requires_grad = True
-            if hasattr(self, 'code_converter'):
             for m in [self.code_converter and self.input_converter]:
                 for p in m.parameters():
                     del p.DO_NOT_TRAIN
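
A note on the hunk above: `[self.code_converter and self.input_converter]` is a single-element list, because Python's `and` returns its last truthy operand, so that loop only re-enables `input_converter`; listing both modules with a comma would presumably be needed to unfreeze both. A minimal, self-contained sketch of the freeze-then-unfreeze pattern used here (module names and sizes are illustrative, not taken from this commit):

import torch.nn as nn

class FreezeDemo(nn.Module):
    """Illustrative modules; only the freeze/unfreeze mechanics mirror the diff above."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(1024, 1024)       # stays frozen
        self.input_converter = nn.Linear(512, 1024)
        self.code_converter = nn.Linear(1024, 1024)
        # Freeze everything first...
        for p in self.parameters():
            p.DO_NOT_TRAIN = True                   # custom flag the training code checks to skip a parameter
            p.requires_grad = False
        # ...then selectively re-enable the converter modules (note the comma, not `and`).
        for m in [self.input_converter, self.code_converter]:
            for p in m.parameters():
                del p.DO_NOT_TRAIN
                p.requires_grad = True

demo = FreezeDemo()
print(sum(p.requires_grad for p in demo.parameters()))  # 4 of 6 parameters are trainable again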
@@ -213,8 +190,8 @@ class TransformerDiffusion(nn.Module):
     def timestep_independent(self, prior, expected_seq_len):
         if self.new_code_expansion:
             prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
-        code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
-        code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
+        code_emb = self.input_converter(prior)
+        code_emb = self.code_converter(code_emb)

         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
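
The comment in this hunk refers to per-sample conditioning dropout in the spirit of classifier-free guidance: with probability `unconditioned_percentage`, a batch element's conditioning embedding is replaced wholesale. A minimal sketch of that idea, using a stand-in tensor for what is presumably a learned unconditioned embedding in the real model:

import torch

B, T, C = 4, 400, 1024
code_emb = torch.randn(B, T, C)     # conditioning embeddings from the prenet
null_emb = torch.randn(1, 1, C)     # stand-in for a learned "unconditioned" embedding
unconditioned_percentage = 0.3

# Per-sample mask: True means "drop the conditioning for this batch element entirely".
drop = torch.rand(B, 1, 1) < unconditioned_percentage
code_emb = torch.where(drop, null_emb.expand(B, T, C), code_emb)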
@@ -350,47 +327,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
                 p.grad *= .2


-class TransformerDiffusionWithARPrior(nn.Module):
-    def __init__(self, freeze_diff=False, **kwargs):
-        super().__init__()
-
-        self.internal_step = 0
-        from models.audio.music.gpt_music import GptMusicLower
-        self.ar = GptMusicLower(dim=512, layers=12)
-        for p in self.ar.parameters():
-            p.DO_NOT_TRAIN = True
-            p.requires_grad = False
-
-        self.diff = TransformerDiffusion(ar_prior=True, **kwargs)
-        if freeze_diff:
-            for p in self.diff.parameters():
-                p.DO_NOT_TRAIN = True
-                p.requires_grad = False
-            for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()):
-                del p.DO_NOT_TRAIN
-                p.requires_grad = True
-
-    def get_grad_norm_parameter_groups(self):
-        groups = {
-            'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])),
-            'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])),
-            'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
-            'out': list(self.diff.out.parameters()),
-            'x_proj': list(self.diff.inp_block.parameters()),
-            'layers': list(self.diff.layers.parameters()),
-            'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()),
-            'time_embed': list(self.diff.time_embed.parameters()),
-        }
-        return groups
-
-    def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
-        with torch.no_grad():
-            prior = self.ar(truth_mel, conditioning_input, return_latent=True)
-
-        diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free)
-        return diff
-
-
 class TransformerDiffusionWithPretrainedVqvae(nn.Module):
     def __init__(self, vqargs, **kwargs):
         super().__init__()
@@ -592,11 +528,6 @@ def register_transformer_diffusion12_with_quantizer(opt_net, opt):
     return TransformerDiffusionWithQuantizer(**opt_net['kwargs'])


-@register_model
-def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
-    return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
-
-
 @register_model
 def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
     return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])
@@ -659,7 +590,7 @@ def test_vqvae_model():
     model = TransformerDiffusionWithPretrainedVqvae(in_channels=100, out_channels=200,
                                                     model_channels=1024, contraction_dim=512,
                                                     prenet_channels=1024, num_heads=8,
-                                                    input_vec_dim=512, num_layers=12, prenet_layers=6, ar_prior=True,
+                                                    input_vec_dim=512, num_layers=12, prenet_layers=6,
                                                     dropout=.1, vqargs= {
                                                         'positional_dims': 1, 'channels': 80,
                                                         'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
@@ -720,28 +651,6 @@ def test_multi_vqvae_model():
     model.diff.get_grad_norm_parameter_groups()


-def test_ar_model():
-    clip = torch.randn(2, 256, 400)
-    cond = torch.randn(2, 256, 400)
-    ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithARPrior(model_channels=2048, prenet_channels=1536,
-                                            input_vec_dim=512, num_layers=16, prenet_layers=6, freeze_diff=True,
-                                            unconditioned_percentage=.4)
-    model.get_grad_norm_parameter_groups()
-
-    ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
-    model.ar.load_state_dict(ar_weights, strict=True)
-    diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth')
-    pruned_diff_weights = {}
-    for k,v in diff_weights.items():
-        if k.startswith('diff.'):
-            pruned_diff_weights[k.replace('diff.', '')] = v
-    model.diff.load_state_dict(pruned_diff_weights, strict=False)
-    torch.save(model.state_dict(), 'sample.pth')
-
-    model(clip, ts, cond, conditioning_input=cond)
-
-
 def test_cheater_model():
     clip = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
@@ -776,4 +685,5 @@ def extract_diff(in_f, out_f, remove_head=False):

 if __name__ == '__main__':
     #extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
-    test_cheater_model()
+    #test_cheater_model()
+    extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
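
`extract_diff` itself is outside this excerpt; judging by the `'diff.'` prefix-stripping loop in the removed `test_ar_model` above, it presumably pulls the inner diffusion weights out of a wrapper checkpoint. A hedged sketch of that prefix-stripping pattern (the helper name and behavior are assumptions, not the actual `extract_diff`):

import torch

def extract_prefixed_weights(in_f, out_f, prefix='diff.'):
    # Hypothetical helper mirroring the loop removed from test_ar_model;
    # not the real extract_diff implementation.
    weights = torch.load(in_f, map_location='cpu')
    pruned = {k[len(prefix):]: v for k, v in weights.items() if k.startswith(prefix)}
    torch.save(pruned, out_f)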
@@ -13,7 +13,6 @@ from trainer.injectors.audio_injectors import MusicCheaterLatentInjector
 from models.diffusion.respace import SpacedDiffusion
 from models.diffusion.respace import space_timesteps
 from models.diffusion.gaussian_diffusion import get_named_beta_schedule
-from models.audio.music.transformer_diffusion12 import TransformerDiffusionWithCheaterLatent


 def join_music(clip1, clip1_cut, clip2, clip2_cut, mix_time, results_dir):
@@ -339,7 +339,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_ar_cheater_gen.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_tfd12_finetune_ar_outputs.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F
 import torchaudio

+from models.audio.music.cheater_gen_ar import ConditioningAR
 from trainer.inject import Injector
 from utils.music_utils import get_music_codegen
 from utils.util import opt_get, load_model_from_config, pad_or_truncate
@@ -426,3 +427,22 @@ class KmeansQuantizerInjector(Injector):
         distances = distances.reshape(b, s, self.centroids.shape[-1])
         labels = distances.argmin(-1)
         return {self.output: labels}
+
+
+class MusicCheaterArInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.cheater_ar = ConditioningAR(1024, layers=24, dropout=0, cond_free_percent=0)
+        self.cheater_ar.load_state_dict(torch.load('../experiments/music_cheater_ar.pth', map_location=torch.device('cpu')))
+        self.cond_key = opt['cheater_latent_key']
+        self.needs_move = True
+
+    def forward(self, state):
+        codes = state[self.input]
+        cond = state[self.cond_key]
+        if self.needs_move:
+            self.cheater_ar = self.cheater_ar.to(codes.device)
+            self.needs_move = False
+        with torch.no_grad():
+            latents = self.cheater_ar(codes, cond, return_latent=True)
+        return {self.output: latents}
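
The new injector follows the same contract as the other injectors in this file: read tensors out of a shared state dict, run a frozen helper model under `no_grad`, and return new entries keyed by the configured output name. A self-contained sketch of that contract with a toy model (the keys, shapes, and the Linear layer are illustrative assumptions; the real injector loads the ConditioningAR checkpoint shown above):

import torch
import torch.nn as nn

class ToyLatentInjector:
    """Mimics the state-in/state-out contract of the injectors above; not part of this commit."""
    def __init__(self, opt):
        self.input = opt['in']
        self.output = opt['out']
        self.cond_key = opt['cheater_latent_key']
        self.model = nn.Linear(80, 256)

    def __call__(self, state):
        codes = state[self.input]
        cond = state[self.cond_key]   # ignored by the toy model; kept to mirror the real forward()
        with torch.no_grad():
            latents = self.model(codes)
        return {self.output: latents}

inj = ToyLatentInjector({'in': 'codes', 'out': 'ar_latents', 'cheater_latent_key': 'cheater_latent'})
state = {'codes': torch.randn(2, 400, 80), 'cheater_latent': torch.randn(2, 80, 400)}
print(inj(state)['ar_latents'].shape)  # torch.Size([2, 400, 256])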