diff --git a/codes/models/audio/tts/unet_diffusion_tts7.py b/codes/models/audio/tts/unet_diffusion_tts7.py index 1894f1df..a323c9b0 100644 --- a/codes/models/audio/tts/unet_diffusion_tts7.py +++ b/codes/models/audio/tts/unet_diffusion_tts7.py @@ -57,10 +57,11 @@ class CheckpointedXTransformerEncoder(nn.Module): Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid to channels-last that XTransformer expects. """ - def __init__(self, needs_permute=True, checkpoint=True, **xtransformer_kwargs): + def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs): super().__init__() self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) self.needs_permute = needs_permute + self.exit_permute = exit_permute if not checkpoint: return @@ -72,7 +73,9 @@ class CheckpointedXTransformerEncoder(nn.Module): if self.needs_permute: x = x.permute(0,2,1) h = self.transformer(x, **kwargs) - return h.permute(0,2,1) + if self.exit_permute: + h = h.permute(0,2,1) + return h class ResBlock(TimestepBlock): diff --git a/codes/models/clip/text_voice_clip.py b/codes/models/clip/text_voice_clip.py index 220b14c5..4a3fd92b 100644 --- a/codes/models/clip/text_voice_clip.py +++ b/codes/models/clip/text_voice_clip.py @@ -3,7 +3,9 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch import einsum +from x_transformers import Encoder +from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder from models.lucidrains.dalle.transformer import Transformer from trainer.networks import register_model from utils.util import opt_get @@ -43,40 +45,69 @@ class VoiceCLIP(nn.Module): text_mask_percentage=0, voice_mask_percentage=0, wav_token_compression=1024, + use_xformers=False, ): super().__init__() self.text_emb = nn.Embedding(num_text_tokens, dim_text) - self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) - self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, - heads=text_heads, rotary_emb=False) self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False) self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech) - self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) - self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, - depth=speech_enc_depth, heads=speech_heads, rotary_emb=False) self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False) + if use_xformers: + self.text_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + use_pos_emb=False, + attn_layers=Encoder( + dim=dim_text, + depth=text_enc_depth, + heads=text_heads, + ff_dropout=.1, + ff_mult=2, + attn_dropout=.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + )) + self.speech_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + use_pos_emb=False, + attn_layers=Encoder( + dim=dim_speech, + depth=speech_enc_depth, + heads=speech_heads, + ff_dropout=.1, + ff_mult=2, + attn_dropout=.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + )) + else: + self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, + heads=text_heads) + self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, + depth=speech_enc_depth, heads=speech_heads) + self.temperature = nn.Parameter(torch.tensor(1.)) self.text_mask_percentage = text_mask_percentage self.voice_mask_percentage = voice_mask_percentage self.wav_token_compression = wav_token_compression + self.xformers = use_xformers + if not use_xformers: + self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) + self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) def forward( self, text, - text_lengths, speech_tokens, - wav_lengths, return_loss=False ): - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. - max_text_len = text_lengths.max() - text = text[:, :max_text_len] - max_mel_len = wav_lengths.max() // self.wav_token_compression - speech_tokens = speech_tokens[:, :max_mel_len] - b, device = text.shape[0], text.device if self.training: text_mask = torch.rand_like(text.float()) > self.text_mask_percentage @@ -86,10 +117,11 @@ class VoiceCLIP(nn.Module): voice_mask = torch.ones_like(speech_tokens.float()).bool() text_emb = self.text_emb(text) - text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) - speech_emb = self.speech_emb(speech_tokens) - speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) + + if not self.xformers: + text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) + speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) enc_text = self.text_transformer(text_emb, mask=text_mask) enc_speech = self.speech_transformer(speech_emb, mask=voice_mask) @@ -120,15 +152,11 @@ def register_voice_clip(opt_net, opt): if __name__ == '__main__': - clip = VoiceCLIP(text_mask_percentage=.2, voice_mask_percentage=.2) + clip = VoiceCLIP(text_mask_percentage=.2, voice_mask_percentage=.2, use_xformers=True) clip(torch.randint(0,256,(2,120)), - torch.tensor([50,100]), torch.randint(0,8192,(2,250)), - torch.tensor([101,102]), return_loss=True) nonloss = clip(torch.randint(0,256,(2,120)), - torch.tensor([50,100]), torch.randint(0,8192,(2,250)), - torch.tensor([101,102]), return_loss=False) print(nonloss.shape) \ No newline at end of file