diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py index 73a8b3e3..145a7260 100644 --- a/codes/models/audio/music/transformer_diffusion8.py +++ b/codes/models/audio/music/transformer_diffusion8.py @@ -196,13 +196,14 @@ class TransformerDiffusion(nn.Module): class TransformerDiffusionWithQuantizer(nn.Module): - def __init__(self, train_quantizer_reconstruction_until=-1, freeze_quantizer_until=10000, **kwargs): + def __init__(self, quantizer_dims=[1024], train_quantizer_reconstruction_until=-1, freeze_quantizer_until=10000, **kwargs): super().__init__() self.internal_step = 0 self.freeze_quantizer_until = freeze_quantizer_until self.train_quantizer_reconstruction_until = train_quantizer_reconstruction_until self.diff = TransformerDiffusion(**kwargs) - self.quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024], codevector_dim=1024, codebook_size=256, + self.quantizer = MusicQuantizer2(inp_channels=kwargs['in_channels'], inner_dim=quantizer_dims, + codevector_dim=quantizer_dims[0], codebook_size=256, codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5) self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature if train_quantizer_reconstruction_until == -1: @@ -327,11 +328,11 @@ def register_transformer_diffusion8_with_ar_prior(opt_net, opt): def test_quant_model(): - clip = torch.randn(2, 256, 400) + clip = torch.randn(2, 100, 401) ts = torch.LongTensor([600, 600]) - model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, - input_vec_dim=1024, num_layers=16, prenet_layers=6, - train_quantizer_reconstruction_until=1000) + model = TransformerDiffusionWithQuantizer(in_channels=100, model_channels=2048, block_channels=1024, prenet_channels=1024, + input_vec_dim=1024, num_layers=16, prenet_layers=6, quantizer_dims=[1024,896,768,512], + train_quantizer_reconstruction_until=-1) model.get_grad_norm_parameter_groups() print_network(model)