diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py
index 5ac5bf25..f489e241 100644
--- a/codes/models/diffusion/diffusion_dvae.py
+++ b/codes/models/diffusion/diffusion_dvae.py
@@ -5,6 +5,7 @@ from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, Res
 
 import torch
 import torch.nn as nn
+from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
 from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
 import models.gpt_voice.my_dvae as mdvae
@@ -120,6 +121,9 @@ class DiffusionDVAE(nn.Module):
             nn.SiLU(),
             linear(time_embed_dim, time_embed_dim),
         )
+        self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim)
+        self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim)
+        self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
 
         self.input_blocks = nn.ModuleList(
             [
@@ -258,7 +262,7 @@ class DiffusionDVAE(nn.Module):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)
 
-    def forward(self, x, timesteps, spectrogram):
+    def forward(self, x, timesteps, spectrogram, conditioning_inputs=None):
         assert x.shape[-1] % 4096 == 0  # This model operates at base//4096 at it's bottom levels, thus this requirement.
 
         # Compute DVAE portion first.
@@ -275,9 +279,17 @@ class DiffusionDVAE(nn.Module):
         spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
         convergence_fns = list(self.convergence_convs)
 
-        # The rest is the diffusion vocoder, built as a standard U-net. spec_h is gradually fed into the encoder.
+        # Timestep embeddings and conditioning signals are combined using a small transformer.
         hs = []
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        if conditioning_inputs is not None:
+            emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
+            emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
+        else:
+            emb = emb1.unsqueeze(1)
+        emb = self.embedding_combiner(emb, self.query_gen(spec_hs[0]))
+
+        # The rest is the diffusion vocoder, built as a standard U-net. spec_h is gradually fed into the encoder.
         next_spec = spec_hs.pop(0)
         next_convergence_fn = convergence_fns.pop(0)
         h = x.type(self.dtype)
@@ -311,6 +323,7 @@ def register_unet_diffusion_dvae(opt_net, opt):
 if __name__ == '__main__':
     clip = torch.randn(1, 1, 81920)
     spec = torch.randn(1, 80, 416)
+    cond = torch.randn(1, 5, 80, 200)
     ts = torch.LongTensor([555])
     model = DiffusionDVAE(32, 2)
-    print(model(clip, ts, spec).shape)
+    print(model(clip, ts, spec, cond)[0].shape)
diff --git a/codes/models/gpt_voice/mini_encoder.py b/codes/models/gpt_voice/mini_encoder.py
new file mode 100644
index 00000000..84fd58fa
--- /dev/null
+++ b/codes/models/gpt_voice/mini_encoder.py
@@ -0,0 +1,122 @@
+import torch
+import torch.nn as nn
+
+
+from models.diffusion.nn import normalization, conv_nd, zero_module
+from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy
+from models.gpt_voice.my_dvae import ResBlock
+
+
+# Combined resnet & full-attention encoder for converting an audio clip into an embedding.
+from utils.util import checkpoint
+
+
+class AudioMiniEncoder(nn.Module):
+    def __init__(self, spec_dim, embedding_dim, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, dropout=0):
+        super().__init__()
+        self.init = nn.Sequential(
+            conv_nd(1, spec_dim, 128, 3, padding=1)
+        )
+        ch = 128
+        res = []
+        for l in range(2):
+            for r in range(resnet_blocks):
+                res.append(ResBlock(ch, dropout, dims=1))
+            res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=2))
+            ch *= 2
+        self.res = nn.Sequential(*res)
+        self.final = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            conv_nd(1, ch, embedding_dim, 1)
+        )
+        attn = []
+        for a in range(attn_blocks):
+            attn.append(AttentionBlock(embedding_dim, num_attn_heads))
+        self.attn = nn.Sequential(*attn)
+
+    def forward(self, x):
+        h = self.init(x)
+        h = self.res(h)
+        h = self.final(h)
+        h = self.attn(h)
+        return h[:, :, 0]
+
+
+
+
+class QueryProvidedAttentionBlock(nn.Module):
+    """
+    An attention block that provides a separate signal for the query vs the keys/values.
+    """
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        use_new_attention_order=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.norm = normalization(channels)
+        self.q = nn.Linear(channels, channels)
+        self.qnorm = nn.LayerNorm(channels)
+        self.kv = conv_nd(1, channels, channels*2, 1)
+        if use_new_attention_order:
+            # split qkv before split heads
+            self.attention = QKVAttention(self.num_heads)
+        else:
+            # split heads before split qkv
+            self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+    def forward(self, qx, kvx):
+        return checkpoint(self._forward, qx, kvx)
+
+    def _forward(self, qx, kvx):
+        q = self.q(self.qnorm(qx)).unsqueeze(1).repeat(1, kvx.shape[1], 1).permute(0,2,1)
+        kv = self.kv(self.norm(kvx.permute(0,2,1)))
+        qkv = torch.cat([q, kv], dim=1)
+        h = self.attention(qkv)
+        h = self.proj_out(h)
+        return kvx + h.permute(0,2,1)
+
+
+# Next up: combine multiple embeddings given a conditioning signal into a single embedding.
+class EmbeddingCombiner(nn.Module):
+    def __init__(self, embedding_dim, attn_blocks=3, num_attn_heads=2, cond_provided=True):
+        super().__init__()
+        block = QueryProvidedAttentionBlock if cond_provided else AttentionBlock
+        self.attn = nn.ModuleList([block(embedding_dim, num_attn_heads) for _ in range(attn_blocks)])
+        self.cond_provided = cond_provided
+
+    # x_s: (b,n,d); b=batch_sz, n=number of embeddings, d=embedding_dim
+    # cond: (b,d) or None
+    def forward(self, x_s, cond=None):
+        assert cond is not None and self.cond_provided or cond is None and not self.cond_provided
+        y = x_s
+        for blk in self.attn:
+            if self.cond_provided:
+                y = blk(cond, y)
+            else:
+                y = blk(y)
+        return y[:, 0]
+
+
+if __name__ == '__main__':
+    x = torch.randn(2, 80, 223)
+    cond = torch.randn(2, 512)
+    encs = [AudioMiniEncoder(80, 512) for _ in range(5)]
+    combiner = EmbeddingCombiner(512)
+
+    e = torch.stack([enc(x) for enc in encs], dim=1)  # stack along dim 1 to get (b, n, d), as EmbeddingCombiner expects
+
+    print(combiner(e, cond).shape)
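
Reviewer note on the new conditioning path in DiffusionDVAE.forward: each conditioning clip is encoded by the shared AudioMiniEncoder, the timestep embedding is prepended, and EmbeddingCombiner collapses the set to a single vector. The sketch below is a minimal, standalone walk-through of those shapes using only the new mini_encoder modules; the 512-dim embedding and the (1, 5, 80, 200) conditioning tensor mirror the test blocks above, while `timestep_emb` and `query` are random stand-ins for `time_embed(timestep_embedding(...))` and `query_gen(spec_hs[0])`, which live inside the U-net and are not reproduced here.

import torch

from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner

time_embed_dim = 512                                    # mirrors the embedding size in the tests above
contextual_embedder = AudioMiniEncoder(80, time_embed_dim)   # one encoder shared across all conditioning clips
embedding_combiner = EmbeddingCombiner(time_embed_dim)

conditioning_inputs = torch.randn(1, 5, 80, 200)        # (batch, n_clips, spec_channels, time)
timestep_emb = torch.randn(1, time_embed_dim)           # stand-in for time_embed(timestep_embedding(...))
query = torch.randn(1, time_embed_dim)                  # stand-in for query_gen(spec_hs[0])

# Encode each conditioning clip separately and stack: (1, 5, 512).
emb2 = torch.stack([contextual_embedder(ci.squeeze(1))
                    for ci in torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1)], dim=1)
# Prepend the timestep embedding as one more "token": (1, 6, 512).
emb = torch.cat([timestep_emb.unsqueeze(1), emb2], dim=1)
# Attend over the six embeddings with the provided query; the result is (1, 512),
# which forward then uses in place of the plain timestep embedding.
emb = embedding_combiner(emb, query)
print(emb.shape)  # torch.Size([1, 512])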