Get diffusion_dvae functional

This commit is contained in:
James Betker 2021-09-14 17:43:31 -06:00
parent e513052fca
commit 0382660159
2 changed files with 139 additions and 4 deletions

View File

@ -5,6 +5,7 @@ from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, Res
import torch
import torch.nn as nn
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
import models.gpt_voice.my_dvae as mdvae
@ -120,6 +121,9 @@ class DiffusionDVAE(nn.Module):
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim)
self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim)
self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
self.input_blocks = nn.ModuleList(
[
@ -258,7 +262,7 @@ class DiffusionDVAE(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps, spectrogram):
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None):
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
# Compute DVAE portion first.
@ -275,9 +279,17 @@ class DiffusionDVAE(nn.Module):
spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
convergence_fns = list(self.convergence_convs)
# The rest is the diffusion vocoder, built as a standard U-net. spec_h is gradually fed into the encoder.
# Timestep embeddings and conditioning signals are combined using a small transformer.
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if conditioning_inputs is not None:
emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
else:
emb = emb1.unsqueeze(1)
emb = self.embedding_combiner(emb, self.query_gen(spec_hs[0]))
# The rest is the diffusion vocoder, built as a standard U-net. spec_h is gradually fed into the encoder.
next_spec = spec_hs.pop(0)
next_convergence_fn = convergence_fns.pop(0)
h = x.type(self.dtype)
@ -311,6 +323,7 @@ def register_unet_diffusion_dvae(opt_net, opt):
if __name__ == '__main__':
clip = torch.randn(1, 1, 81920)
spec = torch.randn(1, 80, 416)
cond = torch.randn(1, 5, 80, 200)
ts = torch.LongTensor([555])
model = DiffusionDVAE(32, 2)
print(model(clip, ts, spec).shape)
print(model(clip, ts, spec, cond)[0].shape)

View File

@ -0,0 +1,122 @@
import torch
import torch.nn as nn
from models.diffusion.nn import normalization, conv_nd, zero_module
from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy
from models.gpt_voice.my_dvae import ResBlock
# Combined resnet & full-attention encoder for converting an audio clip into an embedding.
from utils.util import checkpoint
class AudioMiniEncoder(nn.Module):
def __init__(self, spec_dim, embedding_dim, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, dropout=0):
super().__init__()
self.init = nn.Sequential(
conv_nd(1, spec_dim, 128, 3, padding=1)
)
ch = 128
res = []
for l in range(2):
for r in range(resnet_blocks):
res.append(ResBlock(ch, dropout, dims=1))
res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=2))
ch *= 2
self.res = nn.Sequential(*res)
self.final = nn.Sequential(
normalization(ch),
nn.SiLU(),
conv_nd(1, ch, embedding_dim, 1)
)
attn = []
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
self.attn = nn.Sequential(*attn)
def forward(self, x):
h = self.init(x)
h = self.res(h)
h = self.final(h)
h = self.attn(h)
return h[:, :, 0]
class QueryProvidedAttentionBlock(nn.Module):
"""
An attention block that provides a separate signal for the query vs the keys/parameters.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = normalization(channels)
self.q = nn.Linear(channels, channels)
self.qnorm = nn.LayerNorm(channels)
self.kv = conv_nd(1, channels, channels*2, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, qx, kvx):
return checkpoint(self._forward, qx, kvx)
def _forward(self, qx, kvx):
q = self.q(self.qnorm(qx)).unsqueeze(1).repeat(1, kvx.shape[1], 1).permute(0,2,1)
kv = self.kv(self.norm(kvx.permute(0,2,1)))
qkv = torch.cat([q, kv], dim=1)
h = self.attention(qkv)
h = self.proj_out(h)
return kvx + h.permute(0,2,1)
# Next up: combine multiple embeddings given a conditioning signal into a single embedding.
class EmbeddingCombiner(nn.Module):
def __init__(self, embedding_dim, attn_blocks=3, num_attn_heads=2, cond_provided=True):
super().__init__()
block = QueryProvidedAttentionBlock if cond_provided else AttentionBlock
self.attn = nn.ModuleList([block(embedding_dim, num_attn_heads) for _ in range(attn_blocks)])
self.cond_provided = cond_provided
# x_s: (b,n,d); b=batch_sz, n=number of embeddings, d=embedding_dim
# cond: (b,d) or None
def forward(self, x_s, cond=None):
assert cond is not None and self.cond_provided or cond is None and not self.cond_provided
y = x_s
for blk in self.attn:
if self.cond_provided:
y = blk(cond, y)
else:
y = blk(y)
return y[:, 0]
if __name__ == '__main__':
x = torch.randn(2, 80, 223)
cond = torch.randn(2, 512)
encs = [AudioMiniEncoder(80, 512) for _ in range(5)]
combiner = EmbeddingCombiner(512)
e = torch.stack([e(x) for e in encs], dim=2)
print(combiner(e, cond).shape)