From 71b73db044d2a5ed87fcfe812d848daf96f0ad43 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 7 Apr 2022 11:34:10 -0600 Subject: [PATCH] clean up --- codes/models/audio/tts/ctc_code_generator.py | 191 -------- codes/models/audio/tts/unet_diffusion_tts5.py | 458 ------------------ codes/models/audio/tts/unet_diffusion_tts6.py | 458 ------------------ codes/models/audio/tts/unet_diffusion_tts8.py | 312 ------------ 4 files changed, 1419 deletions(-) delete mode 100644 codes/models/audio/tts/ctc_code_generator.py delete mode 100644 codes/models/audio/tts/unet_diffusion_tts5.py delete mode 100644 codes/models/audio/tts/unet_diffusion_tts6.py delete mode 100644 codes/models/audio/tts/unet_diffusion_tts8.py diff --git a/codes/models/audio/tts/ctc_code_generator.py b/codes/models/audio/tts/ctc_code_generator.py deleted file mode 100644 index 653ffcbe..00000000 --- a/codes/models/audio/tts/ctc_code_generator.py +++ /dev/null @@ -1,191 +0,0 @@ -import json - -import torch -import torch.nn as nn -import torch.nn.functional as F -from x_transformers import Encoder, TransformerWrapper - -from models.audio.tts.unet_diffusion_tts6 import CheckpointedLayer -from models.audio.tts.unified_voice2 import ConditioningEncoder -from models.audio.tts.tacotron2.text.cleaners import english_cleaners -from trainer.networks import register_model -from utils.util import opt_get - - -def clustered_mask(probability, shape, dev, lateral_expansion_radius_max=3): - """ - Produces a masking vector of the specified shape where each element has probability to be zero. - lateral_expansion_radius_max neighbors of any element that is zero also have a 50% chance to be zero. - Effectively, this produces clusters of masks tending to be lateral_expansion_radius_max wide. - - Note: This means the algorithm has a far higher output probability for zeros then . - """ - mask = torch.rand(shape, device=dev) - mask = (mask < probability).float() - kernel = torch.tensor([.5 for _ in range(lateral_expansion_radius_max)] + [1] + [.5 for _ in range(lateral_expansion_radius_max)], device=dev) - mask = F.conv1d(mask.unsqueeze(1), kernel.view(1,1,2*lateral_expansion_radius_max+1), padding=lateral_expansion_radius_max).squeeze(1) - return torch.bernoulli(torch.clamp(mask, 0, 1)) == 0 # ==0 logically inverts the mask. - - -class CheckpointedTransformerWrapper(nn.Module): - """ - Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid - to channels-last that XTransformer expects. 
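    Note (inferred from the code below): despite the wording above, this variant wraps a
    token-based TransformerWrapper rather than a ContinuousTransformerWrapper, and its
    forward() passes inputs straight through without any permute.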
- """ - def __init__(self, **xtransformer_kwargs): - super().__init__() - self.transformer = TransformerWrapper(**xtransformer_kwargs) - - for i in range(len(self.transformer.transformer.attn_layers.layers)): - n, b, r = self.transformer.transformer.attn_layers.layers[i] - self.transformer.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) - - def forward(self, *args, **kwargs): - return self.transformer(*args, **kwargs) - - -class CtcCodeGenerator(nn.Module): - def __init__(self, model_dim=512, layers=10, num_heads=8, dropout=.1, ctc_codes=36, max_pad=121, max_repeat=30, mask_probability=.1): - super().__init__() - self.max_pad = max_pad - self.max_repeat = max_repeat - self.mask_probability = mask_probability - self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=num_heads, mean=True) - self.initial_embedding = nn.Embedding(ctc_codes, model_dim) - self.combiner = nn.Linear(model_dim*2, model_dim) - self.transformer = TransformerWrapper( - num_tokens=max_pad*max_repeat+1, - max_seq_len=-1, # Unneeded for rotary embeddings. - attn_layers=Encoder( - dim=model_dim, - depth=layers, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True - ) - ) - self.transformer.token_emb = nn.Identity() # This class handles the initial embeddings. - self.transformer.to_logits = nn.Identity() - self.ctc_head = nn.Linear(model_dim, max_pad*max_repeat+1) - self.inp_head = nn.Linear(model_dim, ctc_codes) - - def forward(self, conditioning_input, codes, separators, repeats, unpadded_lengths): - max_len = unpadded_lengths.max() - codes = codes[:, :max_len] - loss_mask = torch.ones_like(codes) - for i, l in enumerate(unpadded_lengths): - loss_mask[i, l:] = 0 - if self.training: - codes = clustered_mask(self.mask_probability, codes.shape, codes.device) * codes - - if separators.max() > self.max_pad: - print(f"Got unexpectedly long separators. Max: {separators.max()}, {separators}") - separators = torch.clip(separators, 0, self.max_pad) - separators = separators[:, :max_len] - if repeats.max() > self.max_repeat: - print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}") - repeats = torch.clip(repeats, 1, self.max_repeat) - repeats = repeats[:, :max_len] - repeats = repeats - 1 # min(repeats) is 1; make it 0 to avoid wasting a prediction slot. - labels = separators + repeats * self.max_pad - - # Perform conditioning encoder in FP32, with the transformer in FP16 - cond = self.conditioning_encoder(conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1) - h = torch.cat([cond, self.initial_embedding(codes)], dim=-1) - h = self.combiner(h) - with torch.autocast(codes.device.type): - logits = self.transformer(h) - ctc_pred = self.ctc_head(logits) - code_pred = self.inp_head(logits) - - ctcloss = F.cross_entropy(ctc_pred.float().permute(0,2,1), labels, reduction='none') - ctcloss = torch.mean(ctcloss * loss_mask) - codeloss = F.cross_entropy(code_pred.float().permute(0,2,1), codes, reduction='none') - codeloss = torch.mean(codeloss * loss_mask) - return ctcloss, codeloss - - def generate(self, speech_conditioning_input, texts): - codes = [] - max_seq = 50 - for text in texts: - # First, generate CTC codes from the given texts. 
- vocab = json.loads('{" ": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "\'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}') - text = english_cleaners(text) - text = text.strip().upper() - cd = [] - for c in text: - if c not in vocab.keys(): - continue - cd.append(vocab[c]) - codes.append(torch.tensor(cd, device=speech_conditioning_input.device)) - max_seq = max(max_seq, codes[-1].shape[-1]) - # Collate - for i in range(len(codes)): - if codes[i].shape[-1] < max_seq: - codes[i] = F.pad(codes[i], (0, max_seq-codes[i].shape[-1])) - codes = torch.stack(codes, dim=0) - - cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1).repeat(1,codes.shape[1],1) - h = torch.cat([cond, self.initial_embedding(codes)], dim=-1) - h = self.combiner(h) - with torch.autocast(codes.device.type): - logits = self.transformer(h) - ctc_pred = self.ctc_head(logits) - generate = torch.argmax(ctc_pred, dim=-1) - - # De-compress the codes from the generated output - pads = generate % self.max_pad - repeats = (generate // self.max_pad) + 1 - ctc_batch = [] - max_seq = 0 - for bc, bp, br in zip(codes, pads, repeats): - ctc = [] - for c, p, r in zip(bc, bp, br): - for _ in range(p): - ctc.append(0) - for _ in range(r): - ctc.append(c.item()) - ctc_batch.append(torch.tensor(ctc, device=speech_conditioning_input.device)) - max_seq = max(max_seq, ctc_batch[-1].shape[-1]) - - # Collate the batch - for i in range(len(ctc_batch)): - if ctc_batch[i].shape[-1] < max_seq: - ctc_batch[i] = F.pad(ctc_batch[i], (0, max_seq-ctc_batch[i].shape[-1])) - return torch.stack(ctc_batch, dim=0) - -@register_model -def register_ctc_code_generator(opt_net, opt): - return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {})) - - -def inf(): - sd = torch.load('D:\\dlas\\experiments\\train_encoder_build_ctc_alignments_medium\\models\\24000_generator.pth', map_location='cpu') - model = CtcCodeGenerator(model_dim=1024,layers=32).eval() - model.load_state_dict(sd) - with torch.no_grad(): - from data.audio.unsupervised_audio_dataset import load_audio - from scripts.audio.gen.speech_synthesis_utils import wav_to_mel - ref_mel = torch.cat([wav_to_mel(load_audio("D:\\tortoise-tts\\voices\\atkins\\1.wav", 22050))[:,:,:450], - wav_to_mel(load_audio("D:\\tortoise-tts\\voices\\kennard\\1.wav", 22050))[:,:,:450], - wav_to_mel(load_audio("D:\\tortoise-tts\\voices\\grace\\1.wav", 22050))[:,:,:450], - wav_to_mel(load_audio("D:\\tortoise-tts\\voices\\atkins\\1.wav", 22050))[:,:,:450]], dim=0) - ctc = model.generate(ref_mel, (["i suppose though it's too early for them"] * 3) + ["i suppose though it's too early for them, dear"]) - print("Break") - - -if __name__ == '__main__': - #inf() - - mask = clustered_mask(.1, (4,100), 'cpu') - - model = CtcCodeGenerator() - inps = torch.randint(0,36, (4, 300)) - pads = torch.randint(0,100, (4,300)) - repeats = torch.randint(1,20, (4,300)) - conds = torch.randn(4,80,600) - loss1, loss2 = model(conds, inps, pads, repeats, torch.tensor([250, 300, 280, 30])) - print(loss1.shape, loss2.shape) \ No newline at end of file diff --git a/codes/models/audio/tts/unet_diffusion_tts5.py b/codes/models/audio/tts/unet_diffusion_tts5.py deleted file mode 100644 index 0339877d..00000000 --- a/codes/models/audio/tts/unet_diffusion_tts5.py +++ /dev/null @@ -1,458 +0,0 @@ -import functools -from collections import OrderedDict - -import torch -import torch.nn as nn -import 
torch.nn.functional as F -from torch import autocast - -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ - Downsample, Upsample, TimestepBlock -from models.audio.tts.mini_encoder import AudioMiniEncoder -from scripts.audio.gen.use_diffuse_tts import ceil_multiple -from trainer.networks import register_model -from utils.util import checkpoint -from x_transformers import Encoder, ContinuousTransformerWrapper - - -class CheckpointedLayer(nn.Module): - """ - Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses - checkpoint for all other args. - """ - def __init__(self, wrap): - super().__init__() - self.wrap = wrap - - def forward(self, x, **kwargs): - kw_requires_grad = {} - kw_no_grad = {} - for k, v in kwargs.items(): - if v is not None and isinstance(v, torch.Tensor) and v.requires_grad: - kw_requires_grad[k] = v - else: - kw_no_grad[k] = v - partial = functools.partial(self.wrap, **kw_no_grad) - return torch.utils.checkpoint.checkpoint(partial, x, **kw_requires_grad) - - -class CheckpointedXTransformerEncoder(nn.Module): - """ - Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid - to channels-last that XTransformer expects. - """ - def __init__(self, **xtransformer_kwargs): - super().__init__() - self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) - - for i in range(len(self.transformer.attn_layers.layers)): - n, b, r = self.transformer.attn_layers.layers[i] - self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) - - def forward(self, x): - x = x.permute(0,2,1) - h = self.transformer(x) - return h.permute(0,2,1) - - -class ResBlock(TimestepBlock): - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - dims=2, - kernel_size=3, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - padding = 1 if kernel_size == 3 else 2 - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 1, padding=0), - ) - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, x, emb - ) - - def _forward(self, x, emb): - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class DiffusionTts(nn.Module): - """ - The full UNet model with attention and timestep embedding. 
- - Customized to be conditioned on an aligned token prior. - - :param in_channels: channels in the input Tensor. - :param num_tokens: number of tokens (e.g. characters) which can be provided. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - - def __init__( - self, - model_channels, - in_channels=1, - num_tokens=32, - out_channels=2, # mean and variance - dropout=0, - # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K - channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48), - num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2), - # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0) - # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 - token_conditioning_resolutions=(1,16,), - attention_resolutions=(512,1024,2048), - conv_resample=True, - dims=1, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - kernel_size=3, - scale_factor=2, - conditioning_inputs_provided=True, - time_embed_dim_multiplier=4, - transformer_depths=8, - nil_guidance_fwd_proportion=.3, - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.dtype = torch.float16 if use_fp16 else torch.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.dims = dims - self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion - self.mask_token_id = num_tokens - - padding = 1 if kernel_size == 3 else 2 - - time_embed_dim = model_channels * time_embed_dim_multiplier - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - embedding_dim = model_channels * 8 - self.code_embedding = nn.Embedding(num_tokens+1, embedding_dim) - self.conditioning_enabled = conditioning_inputs_provided - if conditioning_inputs_provided: - self.contextual_embedder = AudioMiniEncoder(in_channels, embedding_dim, base_channels=32, depth=6, resnet_blocks=1, - attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) - self.conditioning_encoder = CheckpointedXTransformerEncoder( - max_seq_len=-1, # Should be 
unused - use_pos_emb=False, - attn_layers=Encoder( - dim=embedding_dim, - depth=transformer_depths, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) - ) - ] - ) - token_conditioning_blocks = [] - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - - for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)): - if ds in token_conditioning_resolutions: - token_conditioning_block = nn.Conv1d(embedding_dim, ch, 1) - token_conditioning_block.weight.data *= .02 - self.input_blocks.append(token_conditioning_block) - token_conditioning_blocks.append(token_conditioning_block) - - for _ in range(num_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=int(mult * model_channels), - dims=dims, - kernel_size=kernel_size, - ) - ] - ch = int(mult * model_channels) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=1, pad=0 - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - mid_transformer = CheckpointedXTransformerEncoder( - max_seq_len=-1, # Should be unused - use_pos_emb=False, - attn_layers=Encoder( - dim=ch, - depth=transformer_depths, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) - ) - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - kernel_size=kernel_size, - ), - mid_transformer, - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - kernel_size=kernel_size, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]: - for i in range(num_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=int(model_channels * mult), - dims=dims, - kernel_size=kernel_size, - ) - ] - ch = int(model_channels * mult) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads_upsample, - num_head_channels=num_head_channels, - ) - ) - if level and i == num_blocks: - out_ch = ch - layers.append( - Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), - ) - - def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', - strict: bool = True): - # Temporary hack to allow the addition of nil-guidance token embeddings to the existing guidance embeddings. 
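        # (Descriptive comment.) The intended behavior: if the checkpoint's embedding table has
        # fewer rows than the current model, because the nil-guidance/mask token was added after
        # the checkpoint was saved, the missing rows are filled with small random values so the
        # rest of the weights can still be loaded.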
- lsd = self.state_dict() - revised = 0 - for i, blk in enumerate(self.input_blocks): - if isinstance(blk, nn.Embedding): - key = f'input_blocks.{i}.weight' - if state_dict[key].shape[0] != lsd[key].shape[0]: - t = torch.randn_like(lsd[key]) * .02 - t[:state_dict[key].shape[0]] = state_dict[key] - state_dict[key] = t - revised += 1 - print(f"Loaded experimental unet_diffusion_net with {revised} modifications.") - return super().load_state_dict(state_dict, strict) - - - - def forward(self, x, timesteps, tokens, conditioning_input=None): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param tokens: an aligned text input. - :return: an [N x C x ...] Tensor of outputs. - """ - with autocast(x.device.type): - orig_x_shape = x.shape[-1] - cm = ceil_multiple(x.shape[-1], 2048) - if cm != 0: - pc = (cm-x.shape[-1])/x.shape[-1] - x = F.pad(x, (0,cm-x.shape[-1])) - tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1]))) - if self.conditioning_enabled: - assert conditioning_input is not None - - hs = [] - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - # Mask out guidance tokens for un-guided diffusion. - if self.training and self.nil_guidance_fwd_proportion > 0: - token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion - tokens = torch.where(token_mask, self.mask_token_id, tokens) - code_emb = self.code_embedding(tokens).permute(0,2,1) - if self.conditioning_enabled: - cond_emb = self.contextual_embedder(conditioning_input) - code_emb = cond_emb.unsqueeze(-1) * code_emb - code_emb = self.conditioning_encoder(code_emb) - - first = True - time_emb = time_emb.float() - h = x - for k, module in enumerate(self.input_blocks): - if isinstance(module, nn.Conv1d): - h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') - h = h + h_tok - else: - with autocast(x.device.type, enabled=not first): - # First block has autocast disabled to allow a high precision signal to be properly vectorized. - h = module(h, time_emb) - hs.append(h) - first = False - h = self.middle_block(h, time_emb) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, time_emb) - - # Last block also has autocast disabled for high-precision outputs. 
- h = h.float() - out = self.out(h) - return out[:, :, :orig_x_shape] - - -@register_model -def register_diffusion_tts5(opt_net, opt): - return DiffusionTts(**opt_net['kwargs']) - - -# Test for ~4 second audio clip at 22050Hz -if __name__ == '__main__': - clip = torch.randn(2, 1, 32768) - tok = torch.randint(0,30, (2,388)) - cond = torch.randn(2, 1, 44000) - ts = torch.LongTensor([600, 600]) - model = DiffusionTts(128, - channel_mult=[1,1.5,2, 3, 4, 6, 8], - num_res_blocks=[2, 2, 2, 2, 2, 2, 1], - token_conditioning_resolutions=[1,4,16,64], - attention_resolutions=[], - num_heads=8, - kernel_size=3, - scale_factor=2, - conditioning_inputs_provided=True, - time_embed_dim_multiplier=4) - model(clip, ts, tok, cond) - torch.save(model.state_dict(), 'test_out.pth') - diff --git a/codes/models/audio/tts/unet_diffusion_tts6.py b/codes/models/audio/tts/unet_diffusion_tts6.py deleted file mode 100644 index c3d76625..00000000 --- a/codes/models/audio/tts/unet_diffusion_tts6.py +++ /dev/null @@ -1,458 +0,0 @@ -import functools -import random - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import autocast - -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \ - Downsample, Upsample, TimestepBlock -from models.audio.tts.mini_encoder import AudioMiniEncoder -from scripts.audio.gen.use_diffuse_tts import ceil_multiple -from trainer.networks import register_model -from utils.util import checkpoint -from x_transformers import Encoder, ContinuousTransformerWrapper - - -class CheckpointedLayer(nn.Module): - """ - Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses - checkpoint for all other args. - """ - def __init__(self, wrap): - super().__init__() - self.wrap = wrap - - def forward(self, x, **kwargs): - kw_requires_grad = {} - kw_no_grad = {} - for k, v in kwargs.items(): - if v is not None and isinstance(v, torch.Tensor) and v.requires_grad: - kw_requires_grad[k] = v - else: - kw_no_grad[k] = v - partial = functools.partial(self.wrap, **kw_no_grad) - return torch.utils.checkpoint.checkpoint(partial, x, **kw_requires_grad) - - -class CheckpointedXTransformerEncoder(nn.Module): - """ - Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid - to channels-last that XTransformer expects. 
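    Concretely, forward() takes channels-first input of shape [N, C, T], permutes it to the
    [N, T, C] layout the transformer expects, and permutes the result back before returning.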
- """ - def __init__(self, **xtransformer_kwargs): - super().__init__() - self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) - - for i in range(len(self.transformer.attn_layers.layers)): - n, b, r = self.transformer.attn_layers.layers[i] - self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) - - def forward(self, x): - x = x.permute(0,2,1) - h = self.transformer(x) - return h.permute(0,2,1) - - -class ResBlock(TimestepBlock): - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - dims=2, - kernel_size=3, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - padding = 1 if kernel_size == 3 else 2 - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 1, padding=0), - ) - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, x, emb - ) - - def _forward(self, x, emb): - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class DiffusionTts(nn.Module): - """ - The full UNet model with attention and timestep embedding. - - Customized to be conditioned on an aligned token prior. - - :param in_channels: channels in the input Tensor. - :param num_tokens: number of tokens (e.g. characters) which can be provided. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. 
- """ - - def __init__( - self, - model_channels, - in_channels=1, - num_tokens=32, - out_channels=2, # mean and variance - dropout=0, - # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K - channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48), - num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2), - # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0) - # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 - token_conditioning_resolutions=(1,16,), - attention_resolutions=(512,1024,2048), - conv_resample=True, - dims=1, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - kernel_size=3, - scale_factor=2, - time_embed_dim_multiplier=4, - cond_transformer_depth=8, - mid_transformer_depth=8, - nil_guidance_fwd_proportion=.3, - super_sampling=False, - super_sampling_max_noising_factor=.1, - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if super_sampling: - in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input. - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.dtype = torch.float16 if use_fp16 else torch.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.dims = dims - self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion - self.mask_token_id = num_tokens - self.super_sampling_enabled = super_sampling - self.super_sampling_max_noising_factor = super_sampling_max_noising_factor - padding = 1 if kernel_size == 3 else 2 - - time_embed_dim = model_channels * time_embed_dim_multiplier - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - embedding_dim = model_channels * 8 - self.code_embedding = nn.Embedding(num_tokens+1, embedding_dim) - self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1, - attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) - self.conditioning_conv = nn.Conv1d(embedding_dim*2, embedding_dim, 1) - self.conditioning_encoder = CheckpointedXTransformerEncoder( - max_seq_len=-1, # Should be unused - use_pos_emb=False, - attn_layers=Encoder( - dim=embedding_dim, - depth=cond_transformer_depth, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) - ) - ] - ) - token_conditioning_blocks = [] - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - - for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)): - if ds in token_conditioning_resolutions: - token_conditioning_block = nn.Conv1d(embedding_dim, ch, 1) - token_conditioning_block.weight.data *= .02 - self.input_blocks.append(token_conditioning_block) - token_conditioning_blocks.append(token_conditioning_block) - - for _ in range(num_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=int(mult * model_channels), - dims=dims, - kernel_size=kernel_size, - ) - ] - ch = int(mult * model_channels) - if ds in attention_resolutions: - 
layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=1, pad=0 - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - mid_transformer = CheckpointedXTransformerEncoder( - max_seq_len=-1, # Should be unused - use_pos_emb=False, - attn_layers=Encoder( - dim=ch, - depth=mid_transformer_depth, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) - ) - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - kernel_size=kernel_size, - ), - mid_transformer, - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - kernel_size=kernel_size, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]: - for i in range(num_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=int(model_channels * mult), - dims=dims, - kernel_size=kernel_size, - ) - ] - ch = int(model_channels * mult) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads_upsample, - num_head_channels=num_head_channels, - ) - ) - if level and i == num_blocks: - out_ch = ch - layers.append( - Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), - ) - - - def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param tokens: an aligned text input. - :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded. - :param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate. - :return: an [N x C x ...] Tensor of outputs. - """ - assert conditioning_input is not None - if self.super_sampling_enabled: - assert lr_input is not None - if self.training and self.super_sampling_max_noising_factor > 0: - noising_factor = random.uniform(0,self.super_sampling_max_noising_factor) - lr_input = torch.randn_like(lr_input) * noising_factor + lr_input - lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest') - x = torch.cat([x, lr_input], dim=1) - - with autocast(x.device.type): - orig_x_shape = x.shape[-1] - cm = ceil_multiple(x.shape[-1], 2048) - if cm != 0: - pc = (cm-x.shape[-1])/x.shape[-1] - x = F.pad(x, (0,cm-x.shape[-1])) - if tokens is not None: - tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1]))) - - hs = [] - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - cond_emb = self.contextual_embedder(conditioning_input) - if tokens is not None: - # Mask out guidance tokens for un-guided diffusion. 
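            # (Descriptive comment.) This is a classifier-free-guidance style dropout of the text
            # condition: during training, each token is independently replaced by the mask token
            # (id num_tokens) with probability nil_guidance_fwd_proportion, so the model also
            # learns to denoise without text guidance.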
- if self.training and self.nil_guidance_fwd_proportion > 0: - token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion - tokens = torch.where(token_mask, self.mask_token_id, tokens) - code_emb = self.code_embedding(tokens).permute(0,2,1) - code_emb = self.conditioning_conv(torch.cat([cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1]), code_emb], dim=1)) - else: - code_emb = cond_emb.unsqueeze(-1) - code_emb = self.conditioning_encoder(code_emb) - - first = True - time_emb = time_emb.float() - h = x - for k, module in enumerate(self.input_blocks): - if isinstance(module, nn.Conv1d): - h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') - h = h + h_tok - else: - with autocast(x.device.type, enabled=not first): - # First block has autocast disabled to allow a high precision signal to be properly vectorized. - h = module(h, time_emb) - hs.append(h) - first = False - h = self.middle_block(h, time_emb) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, time_emb) - - # Last block also has autocast disabled for high-precision outputs. - h = h.float() - out = self.out(h) - return out[:, :, :orig_x_shape] - - -@register_model -def register_diffusion_tts6(opt_net, opt): - return DiffusionTts(**opt_net['kwargs']) - - -# Test for ~4 second audio clip at 22050Hz -if __name__ == '__main__': - clip = torch.randn(2, 1, 32768) - tok = torch.randint(0,30, (2,388)) - cond = torch.randn(2, 1, 44000) - ts = torch.LongTensor([600, 600]) - lr = torch.randn(2,1,10000) - model = DiffusionTts(128, - channel_mult=[1,1.5,2, 3, 4, 6, 8], - num_res_blocks=[2, 2, 2, 2, 2, 2, 1], - token_conditioning_resolutions=[1,4,16,64], - attention_resolutions=[], - num_heads=8, - kernel_size=3, - scale_factor=2, - time_embed_dim_multiplier=4, super_sampling=True) - model(clip, ts, tok, cond, lr) - model(clip, ts, None, cond, lr) - torch.save(model.state_dict(), 'test_out.pth') - diff --git a/codes/models/audio/tts/unet_diffusion_tts8.py b/codes/models/audio/tts/unet_diffusion_tts8.py deleted file mode 100644 index 13c46246..00000000 --- a/codes/models/audio/tts/unet_diffusion_tts8.py +++ /dev/null @@ -1,312 +0,0 @@ -import functools -import random - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import autocast - -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import TimestepEmbedSequential, \ - Downsample, Upsample -from models.audio.tts.mini_encoder import AudioMiniEncoder -from scripts.audio.gen.use_diffuse_tts import ceil_multiple -from trainer.networks import register_model -from x_transformers import Encoder, ContinuousTransformerWrapper - - -class CheckpointedLayer(nn.Module): - """ - Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses - checkpoint for all other args. - """ - def __init__(self, wrap): - super().__init__() - self.wrap = wrap - - def forward(self, x, *args, **kwargs): - for k, v in kwargs.items(): - assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing. - partial = functools.partial(self.wrap, **kwargs) - return torch.utils.checkpoint.checkpoint(partial, x, *args) - - -class CheckpointedXTransformerEncoder(nn.Module): - """ - Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid - to channels-last that XTransformer expects. 
- """ - def __init__(self, **xtransformer_kwargs): - super().__init__() - self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) - - for i in range(len(self.transformer.attn_layers.layers)): - n, b, r = self.transformer.attn_layers.layers[i] - self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) - - def forward(self, x, **kwargs): - x = x.permute(0,2,1) - h = self.transformer(x, **kwargs) - return h.permute(0,2,1) - - -class DiffusionTts(nn.Module): - def __init__( - self, - model_channels, - in_channels=1, - num_tokens=32, - out_channels=2, # mean and variance - dropout=0, - # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K - channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48), - # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0) - # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 - token_conditioning_resolutions=(1,16,), - dims=1, - use_fp16=False, - time_embed_dim_multiplier=4, - cond_transformer_depth=8, - mid_transformer_depth=8, - nil_guidance_fwd_proportion=.3, - # Parameters for super-sampling. - super_sampling=False, - super_sampling_max_noising_factor=.1, - # Parameters for unaligned inputs. - enabled_unaligned_inputs=False, - num_unaligned_tokens=164, - unaligned_encoder_depth=8, - ): - super().__init__() - - if super_sampling: - in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input. - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.dropout = dropout - self.channel_mult = channel_mult - self.dtype = torch.float16 if use_fp16 else torch.float32 - self.dims = dims - self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion - self.mask_token_id = num_tokens - self.super_sampling_enabled = super_sampling - self.super_sampling_max_noising_factor = super_sampling_max_noising_factor - - time_embed_dim = model_channels * time_embed_dim_multiplier - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - embedding_dim = model_channels * 8 - self.code_embedding = nn.Embedding(num_tokens+1, embedding_dim) - self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1, - attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5) - self.conditioning_conv = nn.Conv1d(embedding_dim*3, embedding_dim, 1) - - self.enable_unaligned_inputs = enabled_unaligned_inputs - if enabled_unaligned_inputs: - self.unaligned_embedder = nn.Embedding(num_unaligned_tokens, embedding_dim) - self.unaligned_encoder = CheckpointedXTransformerEncoder( - max_seq_len=-1, - use_pos_emb=False, - attn_layers=Encoder( - dim=embedding_dim, - depth=unaligned_encoder_depth, - heads=embedding_dim//128, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_emb_dim=True, - ) - ) - - self.conditioning_encoder = CheckpointedXTransformerEncoder( - max_seq_len=-1, # Should be unused - use_pos_emb=False, - attn_layers=Encoder( - dim=embedding_dim, - depth=cond_transformer_depth, - heads=embedding_dim//128, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - cross_attend=self.enable_unaligned_inputs, - ) - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - token_conditioning_blocks = [] - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - 
- for level, mult in enumerate(channel_mult): - if ds in token_conditioning_resolutions: - token_conditioning_block = nn.Conv1d(embedding_dim, ch, 1) - token_conditioning_block.weight.data *= .02 - self.input_blocks.append(token_conditioning_block) - token_conditioning_blocks.append(token_conditioning_block) - - out_ch = int(mult * model_channels) - if level != len(channel_mult) - 1: - self.input_blocks.append( - TimestepEmbedSequential( - Downsample( - ch, use_conv=True, dims=dims, out_channels=out_ch, factor=2, ksize=3, pad=1 - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - - self.middle_block = CheckpointedXTransformerEncoder( - max_seq_len=-1, # Should be unused - use_pos_emb=False, - attn_layers=Encoder( - dim=ch, - depth=mid_transformer_depth, - heads=ch//128, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_pos_emb=True, - ) - ) - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - ich = ch + input_block_chans.pop() - out_ch = int(model_channels * mult) - if level != 0: - self.output_blocks.append( - TimestepEmbedSequential(Upsample(ich, use_conv=True, dims=dims, out_channels=out_ch, factor=2)) - ) - else: - self.output_blocks.append( - TimestepEmbedSequential(conv_nd(dims, ich, out_ch, 3, padding=1)) - ) - ch = out_ch - ds //= 2 - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - - - def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param tokens: an aligned text input. - :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded. - :param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate. - :param unaligned_input: A structural input that is not properly aligned with the output of the diffusion model. - Can be combined with a conditioning input to produce more robust conditioning. - :return: an [N x C x ...] Tensor of outputs. - """ - assert conditioning_input is not None - if self.super_sampling_enabled: - assert lr_input is not None - if self.training and self.super_sampling_max_noising_factor > 0: - noising_factor = random.uniform(0,self.super_sampling_max_noising_factor) - lr_input = torch.randn_like(lr_input) * noising_factor + lr_input - lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest') - x = torch.cat([x, lr_input], dim=1) - - if self.enable_unaligned_inputs: - assert unaligned_input is not None - unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1) - unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1) - - with autocast(x.device.type): - orig_x_shape = x.shape[-1] - cm = ceil_multiple(x.shape[-1], 2048) - if cm != 0: - pc = (cm-x.shape[-1])/x.shape[-1] - x = F.pad(x, (0,cm-x.shape[-1])) - if tokens is not None: - tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1]))) - - hs = [] - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - cond_emb = self.contextual_embedder(conditioning_input) - if tokens is not None: - # Mask out guidance tokens for un-guided diffusion. 
- if self.training and self.nil_guidance_fwd_proportion > 0: - token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion - tokens = torch.where(token_mask, self.mask_token_id, tokens) - code_emb = self.code_embedding(tokens).permute(0,2,1) - cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1]) - cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove. - cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1]) - code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1)) - else: - code_emb = cond_emb.unsqueeze(-1) - if self.enable_unaligned_inputs: - code_emb = self.conditioning_encoder(code_emb, context=unaligned_h) - else: - code_emb = self.conditioning_encoder(code_emb) - - first = True - time_emb = time_emb.float() - h = x - for k, module in enumerate(self.input_blocks): - if isinstance(module, nn.Conv1d): - h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') - h = h + h_tok - else: - with autocast(x.device.type, enabled=not first): - # First block has autocast disabled to allow a high precision signal to be properly vectorized. - h = module(h, time_emb) - hs.append(h) - first = False - h = self.middle_block(h) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, time_emb) - - # Last block also has autocast disabled for high-precision outputs. - h = h.float() - out = self.out(h) - return out[:, :, :orig_x_shape] - - -@register_model -def register_diffusion_tts8(opt_net, opt): - return DiffusionTts(**opt_net['kwargs']) - - -# Test for ~4 second audio clip at 22050Hz -if __name__ == '__main__': - clip = torch.randn(2, 1, 32768) - tok = torch.randint(0,30, (2,388)) - cond = torch.randn(2, 1, 44000) - ts = torch.LongTensor([600, 600]) - lr = torch.randn(2,1,10000) - un = torch.randint(0,120, (2,100)) - model = DiffusionTts(128, - channel_mult=[1,1.5,2, 3, 4, 6, 8], - token_conditioning_resolutions=[1,4,16,64], - time_embed_dim_multiplier=4, super_sampling=False, - enabled_unaligned_inputs=True) - model(clip, ts, tok, cond, lr, un) -