forked from mrq/DL-Art-School

tfd12 with ar prior

parent 3f10ce275b
commit ff5c03b460

codes
@@ -1,4 +1,5 @@
 import itertools
+from time import time

 import torch
 import torch.nn as nn
@@ -99,6 +100,8 @@ class TransformerDiffusion(nn.Module):
             ar_prior=False,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
+            # Parameters for re-training head
+            freeze_except_code_converters=False,
     ):
         super().__init__()

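The new freeze_except_code_converters flag is wired up later in __init__ (next hunk): when set, every parameter is frozen and only the converter modules are meant to stay trainable, matching the "# Parameters for re-training head" comment that accompanies it.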
@@ -161,6 +164,16 @@ class TransformerDiffusion(nn.Module):
             zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
         )

+        if freeze_except_code_converters:
+            for p in self.parameters():
+                p.DO_NOT_TRAIN = True
+                p.requires_grad = False
+            for m in [self.input_converter and self.code_converter]:
+                for p in m.parameters():
+                    del p.DO_NOT_TRAIN
+                    p.requires_grad = True
+
+
         self.debug_codes = {}

     def get_grad_norm_parameter_groups(self):
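Note that [self.input_converter and self.code_converter] in the unfreezing loop is a one-element list: Python's "and" returns its second operand when the first is truthy, so only self.code_converter gets thawed while self.input_converter keeps requires_grad=False. A comma was presumably intended; a minimal sketch of that variant (not part of this commit):

    for m in [self.input_converter, self.code_converter]:
        for p in m.parameters():
            del p.DO_NOT_TRAIN      # undo the freeze marker set above
            p.requires_grad = True  # make both converters trainable again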
@@ -391,7 +404,7 @@ class TransformerDiffusionWithPretrainedVqvae(nn.Module):
             'out': list(self.diff.out.parameters()),
             'x_proj': list(self.diff.inp_block.parameters()),
             'layers': list(self.diff.layers.parameters()),
-            'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
+            #'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
             'time_embed': list(self.diff.time_embed.parameters()),
         }
         return groups
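Commenting out the 'code_converters' entry means gradient norms for the input/code converters are no longer reported as a separate group, presumably because this configuration freezes or bypasses those modules. If the group should only be reported when the modules are actually present, a guarded form is possible (an illustrative sketch, not in the commit):

    if hasattr(self.diff, 'input_converter') and hasattr(self.diff, 'code_converter'):
        groups['code_converters'] = list(self.diff.input_converter.parameters()) + \
                                    list(self.diff.code_converter.parameters())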
@@ -534,7 +547,7 @@ def test_vqvae_model():
     model = TransformerDiffusionWithPretrainedVqvae(in_channels=100, out_channels=200,
                                                     model_channels=1024, contraction_dim=512,
                                                     prenet_channels=1024, num_heads=8,
-                                                    input_vec_dim=512, num_layers=12, prenet_layers=6,
+                                                    input_vec_dim=512, num_layers=12, prenet_layers=6, ar_prior=True,
                                                     dropout=.1, vqargs= {
                                                         'positional_dims': 1, 'channels': 80,
                                                         'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
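test_vqvae_model now constructs the wrapped model with ar_prior=True, so the smoke test exercises the autoregressive-prior input path this commit is named after.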
@@ -549,6 +562,20 @@ def test_vqvae_model():
     o = model(clip, ts, cond)
     pg = model.get_grad_norm_parameter_groups()

+    """
+    with torch.no_grad():
+        proj = torch.randn(2, 100, 512).cuda()
+        clip = clip.cuda()
+        ts = ts.cuda()
+        start = time()
+        model = model.cuda().eval()
+        model.diff.enable_fp16 = True
+        ti = model.diff.timestep_independent(proj, clip.shape[2])
+        for k in range(100):
+            model.diff(clip, ts, precomputed_code_embeddings=ti)
+        print(f"Elapsed: {time()-start}")
+    """
+

def test_multi_vqvae_model():
     clip = torch.randn(2, 256, 400)
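The timing block added above sits inside a bare triple-quoted string, so it is an inert string expression rather than executable code; stripping the quotes turns it back into a CUDA micro-benchmark that computes the code embeddings once via timestep_independent and then times 100 diffusion forward passes with precomputed_code_embeddings and fp16 enabled.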
@@ -556,7 +583,7 @@ def test_multi_vqvae_model():
     ts = torch.LongTensor([600, 600])

     # For music:
-    model = TransformerDiffusionWithMultiPretrainedVqvae(in_channels=256, out_channels=200,
+    model = TransformerDiffusionWithMultiPretrainedVqvae(in_channels=256, out_channels=512,
                                                          model_channels=1024, contraction_dim=512,
                                                          prenet_channels=1024, num_heads=8,
                                                          input_vec_dim=2048, num_layers=12, prenet_layers=6,
@@ -604,4 +631,4 @@ def test_ar_model():


if __name__ == '__main__':
-    test_multi_vqvae_model()
+    test_vqvae_model()
@@ -1,4 +1,5 @@
 import random
+from time import time

 import torch
 import torch.nn as nn
@@ -320,9 +321,22 @@ if __name__ == '__main__':
     aligned_sequence = torch.randint(0,8192,(2,100))
     cond = torch.randn(2, 100, 400)
     ts = torch.LongTensor([600, 600])
-    model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5, freeze_everything_except_autoregressive_inputs=True)
+    model = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
+                             in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=True, num_heads=16,
+                             layer_drop=0, unconditioned_percentage=0)
     # Test with latent aligned conditioning
     #o = model(clip, ts, aligned_latent, cond)
     # Test with sequence aligned conditioning
-    o = model(clip, ts, aligned_sequence, cond)
+    #o = model(clip, ts, aligned_sequence, cond)
+
+    with torch.no_grad():
+        proj = torch.randn(2, 100, 1024).cuda()
+        clip = clip.cuda()
+        ts = ts.cuda()
+        start = time()
+        model = model.cuda().eval()
+        ti = model.timestep_independent(proj, clip, clip.shape[2], False)
+        for k in range(100):
+            model(clip, ts, precomputed_aligned_embeddings=ti)
+        print(f"Elapsed: {time()-start}")

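The __main__ block for DiffusionTtsFlat is likewise reworked into an inference-style benchmark: the model is built with an explicit set of sizes, put into eval mode with use_fp16=True, and timestep_independent is evaluated once on a random conditioning projection. The 100 timed forward passes then reuse that result via precomputed_aligned_embeddings instead of recomputing the conditioning at every step.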
@@ -339,7 +339,7 @@ class Trainer:

if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_diffusion_tfd.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_unified_alignment.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
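The default -opt switches from the music-diffusion TFD options file to the GPT-TTS unified-alignment one; the previous config is still reachable by passing it explicitly. An illustrative invocation (the trainer script name is assumed here, it is not shown in this diff):

    python train.py -opt ../options/train_music_diffusion_tfd.yml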