From df0cdf1a4f6d54cfc7fcbdd04ff3a99ac58dc83c Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 11 Jun 2022 08:00:09 -0600 Subject: [PATCH] tfd9 returns with some optimizations --- .../audio/music/transformer_diffusion9.py | 361 ++++++++++++++++++ .../models/audio/tts/unet_diffusion_tts10.py | 330 ---------------- codes/models/lucidrains/x_transformers.py | 4 +- 3 files changed, 363 insertions(+), 332 deletions(-) create mode 100644 codes/models/audio/music/transformer_diffusion9.py delete mode 100644 codes/models/audio/tts/unet_diffusion_tts10.py diff --git a/codes/models/audio/music/transformer_diffusion9.py b/codes/models/audio/music/transformer_diffusion9.py new file mode 100644 index 00000000..b2068d28 --- /dev/null +++ b/codes/models/audio/music/transformer_diffusion9.py @@ -0,0 +1,361 @@ +import itertools + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.arch_util import ResBlock +from models.audio.music.music_quantizer2 import MusicQuantizer2 +from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear +from models.diffusion.unet_diffusion import TimestepBlock +from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding +from trainer.networks import register_model +from utils.util import checkpoint, print_network + + +def is_latent(t): + return t.dtype == torch.float + +def is_sequence(t): + return t.dtype == torch.long + + +class MultiGroupEmbedding(nn.Module): + def __init__(self, tokens, groups, dim): + super().__init__() + self.m = nn.ModuleList([nn.Embedding(tokens, dim // groups) for _ in range(groups)]) + + def forward(self, x): + h = [embedding(x[:, :, i]) for i, embedding in enumerate(self.m)] + return torch.cat(h, dim=-1) + + +class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): + def forward(self, x, emb, rotary_emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb, rotary_emb) + else: + x = layer(x, rotary_emb) + return x + + +class DietAttentionBlock(TimestepBlock): + def __init__(self, in_dim, dim, heads, dropout): + super().__init__() + self.proj = nn.Linear(in_dim, dim) + self.proj.bias.data.zero_() + self.rms_scale_norm = RMSScaleShiftNorm(dim, bias=False) + self.attn = Attention(dim, heads=heads, dim_head=dim//heads, causal=False, dropout=dropout) + self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) + + def forward(self, x, timestep_emb, rotary_emb): + h = self.proj(x) + h = self.rms_scale_norm(h, norm_scale_shift_inp=timestep_emb) + h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) + h = checkpoint(self.ff, h) + return h + x + + +class TransformerDiffusion(nn.Module): + """ + A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way? + """ + def __init__( + self, + prenet_channels=256, + prenet_layers=3, + model_channels=512, + block_channels=256, + num_layers=8, + in_channels=256, + rotary_emb_dim=32, + input_vec_dim=512, + out_channels=512, # mean and variance + num_heads=16, + dropout=0, + use_fp16=False, + ar_prior=False, + # Parameters for regularization. + unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. + ): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.prenet_channels = prenet_channels + self.out_channels = out_channels + self.dropout = dropout + self.unconditioned_percentage = unconditioned_percentage + self.enable_fp16 = use_fp16 + + self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1) + + self.time_embed = nn.Sequential( + linear(prenet_channels, prenet_channels), + nn.SiLU(), + linear(prenet_channels, block_channels), + ) + + self.ar_prior = ar_prior + prenet_heads = prenet_channels//64 + if ar_prior: + self.ar_input = nn.Linear(input_vec_dim, prenet_channels) + self.ar_prior_intg = Encoder( + dim=prenet_channels, + depth=prenet_layers, + heads=prenet_heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + zero_init_branch_output=True, + ff_mult=1, + ) + else: + self.input_converter = nn.Linear(input_vec_dim, prenet_channels) + self.code_converter = Encoder( + dim=prenet_channels, + depth=prenet_layers, + heads=prenet_heads, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + zero_init_branch_output=True, + ff_mult=1, + ) + + self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) + self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) + self.intg = nn.Linear(prenet_channels*2, model_channels) + self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, num_heads, dropout) for _ in range(num_layers)]) + + self.out = nn.Sequential( + normalization(model_channels), + nn.SiLU(), + zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), + ) + + self.debug_codes = {} + + def get_grad_norm_parameter_groups(self): + groups = { + 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), + 'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()), + 'time_embed': list(self.time_embed.parameters()), + } + return groups + + def timestep_independent(self, prior, expected_seq_len): + code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior) + code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) + + # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. + if self.training and self.unconditioned_percentage > 0: + unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), + device=code_emb.device) < self.unconditioned_percentage + code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1), + code_emb) + + expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1) + return expanded_code_emb + + def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False): + if precomputed_code_embeddings is not None: + assert codes is None and conditioning_input is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here." + + unused_params = [] + if conditioning_free: + code_emb = self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) + else: + if precomputed_code_embeddings is not None: + code_emb = precomputed_code_embeddings + else: + code_emb = self.timestep_independent(codes, x.shape[-1]) + unused_params.append(self.unconditioned_embedding) + + with torch.autocast(x.device.type, enabled=self.enable_fp16): + blk_emb = self.time_embed(timestep_embedding(timesteps, self.prenet_channels)) + x = self.inp_block(x).permute(0,2,1) + + rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device) + x = self.intg(torch.cat([x, code_emb], dim=-1)) + for layer in self.layers: + x = checkpoint(layer, x, blk_emb, rotary_pos_emb) + + x = x.float().permute(0,2,1) + out = self.out(x) + + # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. + extraneous_addition = 0 + for p in unused_params: + extraneous_addition = extraneous_addition + p.mean() + out = out + extraneous_addition * 0 + + return out + + +class TransformerDiffusionWithQuantizer(nn.Module): + def __init__(self, quantizer_dims=[1024], freeze_quantizer_until=20000, **kwargs): + super().__init__() + + self.internal_step = 0 + self.freeze_quantizer_until = freeze_quantizer_until + self.diff = TransformerDiffusion(**kwargs) + self.quantizer = MusicQuantizer2(inp_channels=kwargs['in_channels'], inner_dim=quantizer_dims, + codevector_dim=quantizer_dims[0], codebook_size=256, + codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5) + self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature + del self.quantizer.up + + def update_for_step(self, step, *args): + self.internal_step = step + qstep = max(0, self.internal_step - self.freeze_quantizer_until) + self.quantizer.quantizer.temperature = max( + self.quantizer.max_gumbel_temperature * self.quantizer.gumbel_temperature_decay ** qstep, + self.quantizer.min_gumbel_temperature, + ) + + def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False): + quant_grad_enabled = self.internal_step > self.freeze_quantizer_until + with torch.set_grad_enabled(quant_grad_enabled): + proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True) + proj = proj.permute(0,2,1) + + # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. + if not quant_grad_enabled: + unused = 0 + for p in self.quantizer.parameters(): + unused = unused + p.mean() * 0 + proj = proj + unused + diversity_loss = diversity_loss * 0 + + diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free) + if disable_diversity: + return diff + return diff, diversity_loss + + def get_debug_values(self, step, __): + if self.quantizer.total_codes > 0: + return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes], + 'gumbel_temperature': self.quantizer.quantizer.temperature} + else: + return {} + + def get_grad_norm_parameter_groups(self): + groups = { + 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), + 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), + 'quantizer_encoder': list(self.quantizer.encoder.parameters()), + 'quant_codebook': [self.quantizer.quantizer.codevectors], + 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), + 'out': list(self.diff.out.parameters()), + 'x_proj': list(self.diff.inp_block.parameters()), + 'layers': list(self.diff.layers.parameters()), + 'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()), + 'time_embed': list(self.diff.time_embed.parameters()), + } + return groups + + +class TransformerDiffusionWithARPrior(nn.Module): + def __init__(self, freeze_diff=False, **kwargs): + super().__init__() + + self.internal_step = 0 + from models.audio.music.gpt_music import GptMusicLower + self.ar = GptMusicLower(dim=512, layers=12) + for p in self.ar.parameters(): + p.DO_NOT_TRAIN = True + p.requires_grad = False + + self.diff = TransformerDiffusion(ar_prior=True, **kwargs) + if freeze_diff: + for p in self.diff.parameters(): + p.DO_NOT_TRAIN = True + p.requires_grad = False + for p in list(self.diff.ar_prior_intg.parameters()) + list(self.diff.ar_input.parameters()): + del p.DO_NOT_TRAIN + p.requires_grad = True + + def get_grad_norm_parameter_groups(self): + groups = { + 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), + 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), + 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), + 'out': list(self.diff.out.parameters()), + 'x_proj': list(self.diff.inp_block.parameters()), + 'layers': list(self.diff.layers.parameters()), + 'ar_prior_intg': list(self.diff.ar_prior_intg.parameters()), + 'time_embed': list(self.diff.time_embed.parameters()), + } + return groups + + def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False): + with torch.no_grad(): + prior = self.ar(truth_mel, conditioning_input, return_latent=True) + + diff = self.diff(x, timesteps, prior, conditioning_free=conditioning_free) + return diff + + +@register_model +def register_transformer_diffusion9(opt_net, opt): + return TransformerDiffusion(**opt_net['kwargs']) + + +@register_model +def register_transformer_diffusion9_with_quantizer(opt_net, opt): + return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) + + +@register_model +def register_transformer_diffusion9_with_ar_prior(opt_net, opt): + return TransformerDiffusionWithARPrior(**opt_net['kwargs']) + + +def test_quant_model(): + clip = torch.randn(2, 256, 400) + cond = torch.randn(2, 256, 400) + ts = torch.LongTensor([600, 600]) + model = TransformerDiffusionWithQuantizer(in_channels=256, model_channels=3072, block_channels=1536, + prenet_channels=1024, num_heads=12, + input_vec_dim=1024, num_layers=24, prenet_layers=6) + model.get_grad_norm_parameter_groups() + + quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant_r4\\models\\5000_generator.pth') + model.quantizer.load_state_dict(quant_weights, strict=False) + + torch.save(model.state_dict(), 'sample.pth') + print_network(model) + o = model(clip, ts, clip, cond) + + +def test_ar_model(): + clip = torch.randn(2, 256, 400) + cond = torch.randn(2, 256, 400) + ts = torch.LongTensor([600, 600]) + model = TransformerDiffusionWithARPrior(model_channels=3072, block_channels=1536, prenet_channels=1536, + input_vec_dim=512, num_layers=24, prenet_layers=6, freeze_diff=True, + unconditioned_percentage=.4) + model.get_grad_norm_parameter_groups() + + ar_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth') + model.ar.load_state_dict(ar_weights, strict=True) + diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd8\\models\\47500_generator_ema.pth') + pruned_diff_weights = {} + for k,v in diff_weights.items(): + if k.startswith('diff.'): + pruned_diff_weights[k.replace('diff.', '')] = v + model.diff.load_state_dict(pruned_diff_weights, strict=False) + torch.save(model.state_dict(), 'sample.pth') + + model(clip, ts, cond, conditioning_input=cond) + + + +if __name__ == '__main__': + test_quant_model() diff --git a/codes/models/audio/tts/unet_diffusion_tts10.py b/codes/models/audio/tts/unet_diffusion_tts10.py deleted file mode 100644 index 1bbc4f4b..00000000 --- a/codes/models/audio/tts/unet_diffusion_tts10.py +++ /dev/null @@ -1,330 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import autocast - -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ - Downsample, Upsample, TimestepBlock -from models.lucidrains.x_transformers import Encoder -from scripts.audio.gen.use_diffuse_tts import ceil_multiple -from trainer.networks import register_model -from utils.util import checkpoint - - -class ResBlock(TimestepBlock): - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - dims=2, - kernel_size=3, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - padding = 1 if kernel_size == 3 else 2 - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 1, padding=0), - ) - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, x, emb - ) - - def _forward(self, x, emb): - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class DiffusionTts(nn.Module): - def __init__( - self, - model_channels, - in_channels=100, - num_tokens=256, - out_channels=200, # mean and variance - dropout=0, - # m 1, 2, 4, 8 - block_channels= (512,640, 768,1024), - num_res_blocks= (3, 3, 3, 3), - token_conditioning_resolutions=(2,4,8), - attention_resolutions=(2,4,8), - conv_resample=True, - dims=1, - use_fp16=False, - kernel_size=3, - scale_factor=2, - num_heads=None, - time_embed_dim_multiplier=4, - nil_guidance_fwd_proportion=.15, - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.conv_resample = conv_resample - self.dtype = torch.float16 if use_fp16 else torch.float32 - self.dims = dims - self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion - self.mask_token_id = num_tokens - num_heads = model_channels // 64 if num_heads is None else num_heads - - padding = 1 if kernel_size == 3 else 2 - - time_embed_dim = model_channels * time_embed_dim_multiplier - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.code_embedding = nn.Embedding(num_tokens+1, model_channels) - self.conditioning_embedder = nn.Sequential(nn.Conv1d(in_channels, model_channels // 2, 3, padding=1, stride=2), - nn.Conv1d(model_channels//2, model_channels,3,padding=1,stride=2)) - self.conditioning_encoder = Encoder( - dim=model_channels, - depth=4, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) - - self.codes_encoder = Encoder( - dim=model_channels, - depth=8, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rms_scaleshift_norm=True, - ff_glu=True, - rotary_pos_emb=True, - zero_init_branch_output=True, - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) - ) - ] - ) - token_conditioning_blocks = [] - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - - for level, (blk_chan, num_blocks) in enumerate(zip(block_channels, num_res_blocks)): - if ds in token_conditioning_resolutions: - token_conditioning_block = nn.Conv1d(model_channels, ch, 1) - token_conditioning_block.weight.data *= .02 - self.input_blocks.append(token_conditioning_block) - token_conditioning_blocks.append(token_conditioning_block) - - for _ in range(num_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=blk_chan, - dims=dims, - kernel_size=kernel_size, - ) - ] - ch = blk_chan - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(block_channels) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=1, pad=0 - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - ), - AttentionBlock( - ch, - num_heads=num_heads, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, (blk_chan, num_blocks) in list(enumerate(zip(block_channels, num_res_blocks)))[::-1]: - for i in range(num_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=blk_chan, - dims=dims, - kernel_size=kernel_size, - ) - ] - ch = blk_chan - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - ) - ) - if level and i == num_blocks: - out_ch = ch - layers.append( - Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), - ) - - def forward(self, x, timesteps, codes, conditioning_input=None): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param codes: an aligned text input. - :return: an [N x C x ...] Tensor of outputs. - """ - with autocast(x.device.type): - orig_x_shape = x.shape[-1] - cm = ceil_multiple(x.shape[-1], 16) - if cm != 0: - pc = (cm-x.shape[-1])/x.shape[-1] - x = F.pad(x, (0,cm-x.shape[-1])) - codes = F.pad(codes, (0, int(pc * codes.shape[-1]))) - - hs = [] - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - # Mask out guidance tokens for un-guided diffusion. - if self.training and self.nil_guidance_fwd_proportion > 0: - token_mask = torch.rand(codes.shape, device=codes.device) < self.nil_guidance_fwd_proportion - codes = torch.where(token_mask, self.mask_token_id, codes) - code_emb = self.code_embedding(codes).permute(0, 2, 1) - cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) - cond_emb = self.conditioning_encoder(cond_emb)[:, 0] - code_emb = self.codes_encoder(code_emb.permute(0,2,1), norm_scale_shift_inp=cond_emb).permute(0,2,1) - - first = True - time_emb = time_emb.float() - h = x - for k, module in enumerate(self.input_blocks): - if isinstance(module, nn.Conv1d): - h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') - h = h + h_tok - else: - with autocast(x.device.type, enabled=not first): - # First block has autocast disabled to allow a high precision signal to be properly vectorized. - h = module(h, time_emb) - hs.append(h) - first = False - h = self.middle_block(h, time_emb) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, time_emb) - - # Last block also has autocast disabled for high-precision outputs. - h = h.float() - out = self.out(h) - return out[:, :, :orig_x_shape] - - -@register_model -def register_diffusion_tts10(opt_net, opt): - return DiffusionTts(**opt_net['kwargs']) - - -if __name__ == '__main__': - clip = torch.randn(2, 100, 500).cuda() - tok = torch.randint(0,256, (2,230)).cuda() - cond = torch.randn(2, 100, 300).cuda() - ts = torch.LongTensor([600, 600]).cuda() - model = DiffusionTts(512).cuda() - print(sum(p.numel() for p in model.parameters()) / 1000000) - model(clip, ts, tok, cond) - diff --git a/codes/models/lucidrains/x_transformers.py b/codes/models/lucidrains/x_transformers.py index b48eb51e..158d358c 100644 --- a/codes/models/lucidrains/x_transformers.py +++ b/codes/models/lucidrains/x_transformers.py @@ -352,12 +352,12 @@ class RMSNorm(nn.Module): class RMSScaleShiftNorm(nn.Module): - def __init__(self, dim, eps=1e-8): + def __init__(self, dim, eps=1e-8, bias=True): super().__init__() self.scale = dim ** -0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) - self.scale_shift_process = nn.Linear(dim, dim * 2) + self.scale_shift_process = nn.Linear(dim, dim * 2, bias=bias) def forward(self, x, norm_scale_shift_inp): norm = torch.norm(x, dim=-1, keepdim=True) * self.scale