Rework how conditioning inputs are applied to DiffusionVocoder
parent b1248e7114
commit 0ee1c67ce5
@@ -6,7 +6,8 @@ from models.diffusion.nn import normalization, conv_nd, zero_module
 from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy, Upsample
 
 # Combined resnet & full-attention encoder for converting an audio clip into an embedding.
-from utils.util import checkpoint
+from trainer.networks import register_model
+from utils.util import checkpoint, opt_get
 
 
 class ResBlock(nn.Module):
@@ -111,15 +112,25 @@ class AudioMiniEncoder(nn.Module):
         for a in range(attn_blocks):
             attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
         self.attn = nn.Sequential(*attn)
+        self.dim = embedding_dim
 
     def forward(self, x):
         h = self.init(x)
         h = self.res(h)
         h = self.final(h)
-        h = self.attn(h)
+        h = checkpoint(self.attn, h)
         return h[:, :, 0]
 
 
+class AudioMiniEncoderWithClassifierHead(nn.Module):
+    def __init__(self, classes, **kwargs):
+        super().__init__()
+        self.enc = AudioMiniEncoder(**kwargs)
+        self.head = nn.Linear(self.enc.dim, classes)
+
+    def forward(self, x):
+        h = self.enc(x)
+        return self.head(h)
 
 
 class QueryProvidedAttentionBlock(nn.Module):
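Note: the new AudioMiniEncoderWithClassifierHead simply wraps the encoder and maps its embedding to class logits with a single nn.Linear; exposing self.dim on the encoder is what lets the head size itself. Below is a minimal, self-contained sketch of the same pattern — ToyEncoder and its argument names are hypothetical stand-ins, not the repository's AudioMiniEncoder.

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    # Hypothetical stand-in for AudioMiniEncoder: reduces a (B, spec_dim, T)
    # spectrogram to a (B, embedding_dim) vector.
    def __init__(self, spec_dim, embedding_dim):
        super().__init__()
        self.proj = nn.Conv1d(spec_dim, embedding_dim, kernel_size=3, padding=1)
        self.dim = embedding_dim  # exposed so a head can read the output width

    def forward(self, x):
        h = self.proj(x)       # (B, embedding_dim, T)
        return h.mean(dim=-1)  # (B, embedding_dim)

class EncoderWithClassifierHead(nn.Module):
    # Same shape as the new wrapper: encode, then map the embedding to
    # `classes` logits with one Linear layer.
    def __init__(self, classes, **kwargs):
        super().__init__()
        self.enc = ToyEncoder(**kwargs)
        self.head = nn.Linear(self.enc.dim, classes)

    def forward(self, x):
        return self.head(self.enc(x))

if __name__ == '__main__':
    m = EncoderWithClassifierHead(4, spec_dim=80, embedding_dim=512)
    print(m(torch.randn(2, 80, 223)).shape)  # torch.Size([2, 4])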
@@ -188,7 +199,13 @@ class EmbeddingCombiner(nn.Module):
         return y[:, 0]
 
 
+@register_model
+def register_mini_audio_encoder_classifier(opt_net, opt):
+    return AudioMiniEncoderWithClassifierHead(**opt_get(opt_net, ['kwargs'], {}))
+
+
 if __name__ == '__main__':
+    '''
     x = torch.randn(2, 80, 223)
     cond = torch.randn(2, 512)
     encs = [AudioMiniEncoder(80, 512) for _ in range(5)]
@@ -197,3 +214,7 @@ if __name__ == '__main__':
     e = torch.stack([e(x) for e in encs], dim=2)
 
     print(combiner(e, cond).shape)
+    '''
+    x = torch.randn(2, 80, 223)
+    m = AudioMiniEncoderWithClassifierHead(4, 80, 512)
+    print(m(x).shape)
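Note: the @register_model hook plus opt_get follows the trainer's build-from-config convention, where the decorated function constructs the module from opt_net['kwargs']. The sketch below only illustrates that pattern with a toy registry; register_model and opt_get here are simplified stand-ins, not the actual trainer.networks / utils.util implementations.

import torch.nn as nn

# Toy registry illustrating the decorator pattern only; the repository's
# register_model/opt_get are assumed to differ in detail.
MODEL_REGISTRY = {}

def register_model(fn):
    MODEL_REGISTRY[fn.__name__] = fn
    return fn

def opt_get(opt, keys, default=None):
    # Walk nested dict keys; fall back to `default` if any level is missing.
    for k in keys:
        if not isinstance(opt, dict) or k not in opt:
            return default
        opt = opt[k]
    return opt

@register_model
def register_toy_classifier(opt_net, opt):
    kwargs = opt_get(opt_net, ['kwargs'], {})
    return nn.Linear(kwargs.get('in_dim', 512), kwargs.get('classes', 2))

if __name__ == '__main__':
    builder = MODEL_REGISTRY['register_toy_classifier']
    model = builder({'kwargs': {'in_dim': 512, 'classes': 4}}, {})
    print(model)  # Linear(in_features=512, out_features=4, bias=True)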
@@ -123,7 +123,7 @@ class DiffusionVocoderWithRef(nn.Module):
             self.contextual_embedder = AudioMiniEncoder(conditioning_input_dim, time_embed_dim)
             self.query_gen = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
                                               attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
-            self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
+            self.embedding_combiner = EmbeddingCombiner(time_embed_dim, attn_blocks=1)
 
         self.input_blocks = nn.ModuleList(
             [
@@ -303,8 +303,8 @@ class DiffusionVocoderWithRef(nn.Module):
         emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
         if self.conditioning_enabled:
             emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
-            emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
-            emb = self.embedding_combiner(emb, None, self.query_gen(x))
+            emb2 = self.embedding_combiner(emb2, None, self.query_gen(x))
+            emb = emb1 + emb2
         else:
             emb = emb1
 
@@ -335,5 +335,5 @@ if __name__ == '__main__':
     spec = torch.randn(2,512,160)
     cond = torch.randn(2, 3, 80, 173)
     ts = torch.LongTensor([555, 556])
-    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
-    print(model(clip, ts, spec, cond, 4).shape)
+    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True)
+    print(model(clip, ts, spec, cond, 3).shape)
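Note on the forward-pass change: previously the timestep embedding was concatenated with the per-clip conditioning embeddings and the whole stack went through the EmbeddingCombiner; now only the conditioning embeddings are combined (queried by query_gen(x)) and the result is summed with the timestep embedding. A shape-level sketch of the two paths with stand-in tensors and a placeholder combiner — the dimensions and the combiner body are hypothetical, only the combination order mirrors the diff.

import torch

B, N, C = 2, 3, 512                      # batch, conditioning clips, embedding width

emb1 = torch.randn(B, C)                 # stands in for time_embed(timestep_embedding(...))
emb2 = torch.randn(B, N, C)              # stands in for stacked contextual_embedder outputs
query = torch.randn(B, C)                # stands in for query_gen(x)

def embedding_combiner(embs, mask, query):
    # Placeholder for EmbeddingCombiner: attend over the N embeddings with the
    # provided query and reduce them to a single (B, C) vector.
    w = torch.softmax((embs * query.unsqueeze(1)).sum(-1, keepdim=True), dim=1)
    return (w * embs).sum(dim=1)

# Old path: concatenate the timestep embedding in and combine everything jointly.
emb_old = embedding_combiner(torch.cat([emb1.unsqueeze(1), emb2], dim=1), None, query)

# New path: combine only the conditioning embeddings, then add the timestep embedding.
emb_new = embedding_combiner(emb2, None, query) + emb1

print(emb_old.shape, emb_new.shape)      # torch.Size([2, 512]) torch.Size([2, 512])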