forked from mrq/DL-Art-School
Get diffusion_dvae functional
This commit is contained in:
parent
e513052fca
commit
0382660159
|
@ -5,6 +5,7 @@ from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, Res
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
|
||||||
from models.vqvae.vqvae import Quantize
|
from models.vqvae.vqvae import Quantize
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
import models.gpt_voice.my_dvae as mdvae
|
import models.gpt_voice.my_dvae as mdvae
|
||||||
|
@ -120,6 +121,9 @@ class DiffusionDVAE(nn.Module):
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
linear(time_embed_dim, time_embed_dim),
|
linear(time_embed_dim, time_embed_dim),
|
||||||
)
|
)
|
||||||
|
self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim)
|
||||||
|
self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim)
|
||||||
|
self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
|
||||||
|
|
||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
@ -258,7 +262,7 @@ class DiffusionDVAE(nn.Module):
|
||||||
self.middle_block.apply(convert_module_to_f32)
|
self.middle_block.apply(convert_module_to_f32)
|
||||||
self.output_blocks.apply(convert_module_to_f32)
|
self.output_blocks.apply(convert_module_to_f32)
|
||||||
|
|
||||||
def forward(self, x, timesteps, spectrogram):
|
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None):
|
||||||
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
|
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
|
||||||
|
|
||||||
# Compute DVAE portion first.
|
# Compute DVAE portion first.
|
||||||
|
@ -275,9 +279,17 @@ class DiffusionDVAE(nn.Module):
|
||||||
spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
|
spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
|
||||||
convergence_fns = list(self.convergence_convs)
|
convergence_fns = list(self.convergence_convs)
|
||||||
|
|
||||||
# The rest is the diffusion vocoder, built as a standard U-net. spec_h is gradually fed into the encoder.
|
# Timestep embeddings and conditioning signals are combined using a small transformer.
|
||||||
hs = []
|
hs = []
|
||||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||||
|
if conditioning_inputs is not None:
|
||||||
|
emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
|
||||||
|
emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
|
||||||
|
else:
|
||||||
|
emb = emb1.unsqueeze(1)
|
||||||
|
emb = self.embedding_combiner(emb, self.query_gen(spec_hs[0]))
|
||||||
|
|
||||||
|
# The rest is the diffusion vocoder, built as a standard U-net. spec_h is gradually fed into the encoder.
|
||||||
next_spec = spec_hs.pop(0)
|
next_spec = spec_hs.pop(0)
|
||||||
next_convergence_fn = convergence_fns.pop(0)
|
next_convergence_fn = convergence_fns.pop(0)
|
||||||
h = x.type(self.dtype)
|
h = x.type(self.dtype)
|
||||||
|
@ -311,6 +323,7 @@ def register_unet_diffusion_dvae(opt_net, opt):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
clip = torch.randn(1, 1, 81920)
|
clip = torch.randn(1, 1, 81920)
|
||||||
spec = torch.randn(1, 80, 416)
|
spec = torch.randn(1, 80, 416)
|
||||||
|
cond = torch.randn(1, 5, 80, 200)
|
||||||
ts = torch.LongTensor([555])
|
ts = torch.LongTensor([555])
|
||||||
model = DiffusionDVAE(32, 2)
|
model = DiffusionDVAE(32, 2)
|
||||||
print(model(clip, ts, spec).shape)
|
print(model(clip, ts, spec, cond)[0].shape)
|
||||||
|
|
122
codes/models/gpt_voice/mini_encoder.py
Normal file
122
codes/models/gpt_voice/mini_encoder.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
from models.diffusion.nn import normalization, conv_nd, zero_module
|
||||||
|
from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy
|
||||||
|
from models.gpt_voice.my_dvae import ResBlock
|
||||||
|
|
||||||
|
|
||||||
|
# Combined resnet & full-attention encoder for converting an audio clip into an embedding.
|
||||||
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
class AudioMiniEncoder(nn.Module):
|
||||||
|
def __init__(self, spec_dim, embedding_dim, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, dropout=0):
|
||||||
|
super().__init__()
|
||||||
|
self.init = nn.Sequential(
|
||||||
|
conv_nd(1, spec_dim, 128, 3, padding=1)
|
||||||
|
)
|
||||||
|
ch = 128
|
||||||
|
res = []
|
||||||
|
for l in range(2):
|
||||||
|
for r in range(resnet_blocks):
|
||||||
|
res.append(ResBlock(ch, dropout, dims=1))
|
||||||
|
res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=2))
|
||||||
|
ch *= 2
|
||||||
|
self.res = nn.Sequential(*res)
|
||||||
|
self.final = nn.Sequential(
|
||||||
|
normalization(ch),
|
||||||
|
nn.SiLU(),
|
||||||
|
conv_nd(1, ch, embedding_dim, 1)
|
||||||
|
)
|
||||||
|
attn = []
|
||||||
|
for a in range(attn_blocks):
|
||||||
|
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
|
||||||
|
self.attn = nn.Sequential(*attn)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.init(x)
|
||||||
|
h = self.res(h)
|
||||||
|
h = self.final(h)
|
||||||
|
h = self.attn(h)
|
||||||
|
return h[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class QueryProvidedAttentionBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
An attention block that provides a separate signal for the query vs the keys/parameters.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels,
|
||||||
|
num_heads=1,
|
||||||
|
num_head_channels=-1,
|
||||||
|
use_new_attention_order=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
if num_head_channels == -1:
|
||||||
|
self.num_heads = num_heads
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
channels % num_head_channels == 0
|
||||||
|
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||||
|
self.num_heads = channels // num_head_channels
|
||||||
|
self.norm = normalization(channels)
|
||||||
|
self.q = nn.Linear(channels, channels)
|
||||||
|
self.qnorm = nn.LayerNorm(channels)
|
||||||
|
self.kv = conv_nd(1, channels, channels*2, 1)
|
||||||
|
if use_new_attention_order:
|
||||||
|
# split qkv before split heads
|
||||||
|
self.attention = QKVAttention(self.num_heads)
|
||||||
|
else:
|
||||||
|
# split heads before split qkv
|
||||||
|
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||||
|
|
||||||
|
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||||
|
|
||||||
|
def forward(self, qx, kvx):
|
||||||
|
return checkpoint(self._forward, qx, kvx)
|
||||||
|
|
||||||
|
def _forward(self, qx, kvx):
|
||||||
|
q = self.q(self.qnorm(qx)).unsqueeze(1).repeat(1, kvx.shape[1], 1).permute(0,2,1)
|
||||||
|
kv = self.kv(self.norm(kvx.permute(0,2,1)))
|
||||||
|
qkv = torch.cat([q, kv], dim=1)
|
||||||
|
h = self.attention(qkv)
|
||||||
|
h = self.proj_out(h)
|
||||||
|
return kvx + h.permute(0,2,1)
|
||||||
|
|
||||||
|
|
||||||
|
# Next up: combine multiple embeddings given a conditioning signal into a single embedding.
|
||||||
|
class EmbeddingCombiner(nn.Module):
|
||||||
|
def __init__(self, embedding_dim, attn_blocks=3, num_attn_heads=2, cond_provided=True):
|
||||||
|
super().__init__()
|
||||||
|
block = QueryProvidedAttentionBlock if cond_provided else AttentionBlock
|
||||||
|
self.attn = nn.ModuleList([block(embedding_dim, num_attn_heads) for _ in range(attn_blocks)])
|
||||||
|
self.cond_provided = cond_provided
|
||||||
|
|
||||||
|
# x_s: (b,n,d); b=batch_sz, n=number of embeddings, d=embedding_dim
|
||||||
|
# cond: (b,d) or None
|
||||||
|
def forward(self, x_s, cond=None):
|
||||||
|
assert cond is not None and self.cond_provided or cond is None and not self.cond_provided
|
||||||
|
y = x_s
|
||||||
|
for blk in self.attn:
|
||||||
|
if self.cond_provided:
|
||||||
|
y = blk(cond, y)
|
||||||
|
else:
|
||||||
|
y = blk(y)
|
||||||
|
return y[:, 0]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
x = torch.randn(2, 80, 223)
|
||||||
|
cond = torch.randn(2, 512)
|
||||||
|
encs = [AudioMiniEncoder(80, 512) for _ in range(5)]
|
||||||
|
combiner = EmbeddingCombiner(512)
|
||||||
|
|
||||||
|
e = torch.stack([e(x) for e in encs], dim=2)
|
||||||
|
|
||||||
|
print(combiner(e, cond).shape)
|
Loading…
Reference in New Issue
Block a user