Rework how conditioning inputs are applied to DiffusionVocoder

James Betker 2021-10-24 09:08:58 -06:00
parent b1248e7114
commit 0ee1c67ce5
2 changed files with 28 additions and 7 deletions
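For orientation, the core of the rework (second file below): instead of stacking the timestep embedding together with the per-clip conditioning embeddings and letting EmbeddingCombiner attend over all of them, the combiner now attends only over the conditioning embeddings (queried by query_gen(x)) and its output is summed onto the timestep embedding. A minimal standalone sketch of that change, using illustrative shapes and a mean-pooling stand-in for the real EmbeddingCombiner (the names emb1, emb2, and the query come from the diff; everything else here is assumed for illustration):

```python
import torch

# Illustrative sizes only; the real values come from the model configuration.
batch, n_cond_clips, dim = 2, 3, 512
emb1 = torch.randn(batch, dim)                # timestep embedding
emb2 = torch.randn(batch, n_cond_clips, dim)  # one embedding per conditioning clip
query = torch.randn(batch, dim)               # stand-in for self.query_gen(x)

def embedding_combiner(embs, norm, query):
    # Stand-in for EmbeddingCombiner: reduces [batch, n, dim] -> [batch, dim].
    # The real module uses attention; mean pooling is used here only for shape clarity.
    return embs.mean(dim=1)

# Before this commit: the timestep embedding joined the attended sequence.
emb_old = embedding_combiner(torch.cat([emb1.unsqueeze(1), emb2], dim=1), None, query)

# After this commit: only the conditioning embeddings are combined, then added to emb1.
emb_new = emb1 + embedding_combiner(emb2, None, query)

print(emb_old.shape, emb_new.shape)  # torch.Size([2, 512]) torch.Size([2, 512])
```

In the new form the timestep embedding always reaches the UNet directly, with the combined conditioning signal added on top, rather than passing through the combiner's attention itself.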

View File

@@ -6,7 +6,8 @@ from models.diffusion.nn import normalization, conv_nd, zero_module
 from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy, Upsample
 # Combined resnet & full-attention encoder for converting an audio clip into an embedding.
-from utils.util import checkpoint
+from trainer.networks import register_model
+from utils.util import checkpoint, opt_get


 class ResBlock(nn.Module):
@@ -111,15 +112,25 @@ class AudioMiniEncoder(nn.Module):
         for a in range(attn_blocks):
             attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
         self.attn = nn.Sequential(*attn)
+        self.dim = embedding_dim

     def forward(self, x):
         h = self.init(x)
         h = self.res(h)
         h = self.final(h)
-        h = self.attn(h)
+        h = checkpoint(self.attn, h)
         return h[:, :, 0]


+class AudioMiniEncoderWithClassifierHead(nn.Module):
+    def __init__(self, classes, **kwargs):
+        super().__init__()
+        self.enc = AudioMiniEncoder(**kwargs)
+        self.head = nn.Linear(self.enc.dim, classes)
+
+    def forward(self, x):
+        h = self.enc(x)
+        return self.head(h)
+
+
 class QueryProvidedAttentionBlock(nn.Module):
@@ -188,7 +199,13 @@ class EmbeddingCombiner(nn.Module):
         return y[:, 0]


+@register_model
+def register_mini_audio_encoder_classifier(opt_net, opt):
+    return AudioMiniEncoderWithClassifierHead(**opt_get(opt_net, ['kwargs'], {}))
+
+
 if __name__ == '__main__':
     '''
     x = torch.randn(2, 80, 223)
     cond = torch.randn(2, 512)
     encs = [AudioMiniEncoder(80, 512) for _ in range(5)]
@@ -197,3 +214,7 @@ if __name__ == '__main__':
     e = torch.stack([e(x) for e in encs], dim=2)
     print(combiner(e, cond).shape)
     '''
+
+    x = torch.randn(2, 80, 223)
+    m = AudioMiniEncoderWithClassifierHead(4, 80, 512)
+    print(m(x).shape)
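The other addition in this file is the classifier wrapper registered above. A minimal sketch of what the new head computes, with the AudioMiniEncoder output replaced by a random [batch, embedding_dim] tensor (the 512/4 sizes mirror the __main__ smoke test and are illustrative, not fixed by the model):

```python
import torch
import torch.nn as nn

batch, embedding_dim, num_classes = 2, 512, 4

# Stand-in for AudioMiniEncoder(x): one embedding vector per clip.
clip_embedding = torch.randn(batch, embedding_dim)

# AudioMiniEncoderWithClassifierHead adds exactly this projection on top
# of the encoder embedding (nn.Linear(self.enc.dim, classes) in the diff).
head = nn.Linear(embedding_dim, num_classes)
logits = head(clip_embedding)
print(logits.shape)  # torch.Size([2, 4])
```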

View File

@@ -123,7 +123,7 @@ class DiffusionVocoderWithRef(nn.Module):
         self.contextual_embedder = AudioMiniEncoder(conditioning_input_dim, time_embed_dim)
         self.query_gen = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
                                           attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
-        self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
+        self.embedding_combiner = EmbeddingCombiner(time_embed_dim, attn_blocks=1)

         self.input_blocks = nn.ModuleList(
             [
@@ -303,8 +303,8 @@
         emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
         if self.conditioning_enabled:
             emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
-            emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
-            emb = self.embedding_combiner(emb, None, self.query_gen(x))
+            emb2 = self.embedding_combiner(emb2, None, self.query_gen(x))
+            emb = emb1 + emb2
         else:
             emb = emb1
@@ -335,5 +335,5 @@ if __name__ == '__main__':
     spec = torch.randn(2,512,160)
     cond = torch.randn(2, 3, 80, 173)
     ts = torch.LongTensor([555, 556])
-    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
-    print(model(clip, ts, spec, cond, 4).shape)
+    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True)
+    print(model(clip, ts, spec, cond, 3).shape)