Rework how conditioning inputs are applied to DiffusionVocoder
parent b1248e7114
commit 0ee1c67ce5
@@ -6,7 +6,8 @@ from models.diffusion.nn import normalization, conv_nd, zero_module
 from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy, Upsample
 
 # Combined resnet & full-attention encoder for converting an audio clip into an embedding.
-from utils.util import checkpoint
+from trainer.networks import register_model
+from utils.util import checkpoint, opt_get
 
 
 class ResBlock(nn.Module):
@@ -111,15 +112,25 @@ class AudioMiniEncoder(nn.Module):
         for a in range(attn_blocks):
             attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
 
     def forward(self, x):
         h = self.init(x)
         h = self.res(h)
         h = self.final(h)
-        h = self.attn(h)
+        h = checkpoint(self.attn, h)
         return h[:, :, 0]
 
 
+class AudioMiniEncoderWithClassifierHead(nn.Module):
+    def __init__(self, classes, **kwargs):
+        super().__init__()
+        self.enc = AudioMiniEncoder(**kwargs)
+        self.head = nn.Linear(self.enc.dim, classes)
+
+    def forward(self, x):
+        h = self.enc(x)
+        return self.head(h)
+
+
 class QueryProvidedAttentionBlock(nn.Module):
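The h = self.attn(h) to checkpoint(self.attn, h) change wraps the encoder's attention stack in activation checkpointing: intermediate activations inside the stack are discarded on the forward pass and recomputed on backward, trading a little extra compute for lower peak memory. A minimal sketch of the same pattern using the stock torch.utils.checkpoint utility in place of the repo's utils.util.checkpoint wrapper (assumed here to behave similarly); the module below is a toy stand-in, not the real encoder.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint  # stand-in for utils.util.checkpoint

class TinyAttnTail(nn.Module):
    # Toy stand-in for the attention tail of AudioMiniEncoder.
    def __init__(self, dim=64):
        super().__init__()
        self.attn = nn.Sequential(nn.Conv1d(dim, dim, 1), nn.ReLU(), nn.Conv1d(dim, dim, 1))

    def forward(self, h):
        # Activations inside self.attn are not kept; they are recomputed during
        # backward, lowering peak memory at the cost of a second forward pass.
        h = checkpoint(self.attn, h, use_reentrant=False)
        return h[:, :, 0]

if __name__ == '__main__':
    x = torch.randn(2, 64, 100, requires_grad=True)
    TinyAttnTail()(x).sum().backward()  # gradients still flow through the checkpointed block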
@@ -188,7 +199,13 @@ class EmbeddingCombiner(nn.Module):
         return y[:, 0]
 
 
+@register_model
+def register_mini_audio_encoder_classifier(opt_net, opt):
+    return AudioMiniEncoderWithClassifierHead(**opt_get(opt_net, ['kwargs'], {}))
+
+
 if __name__ == '__main__':
     '''
     x = torch.randn(2, 80, 223)
     cond = torch.randn(2, 512)
     encs = [AudioMiniEncoder(80, 512) for _ in range(5)]
@@ -197,3 +214,7 @@ if __name__ == '__main__':
     e = torch.stack([e(x) for e in encs], dim=2)
     print(combiner(e, cond).shape)
     '''
+    x = torch.randn(2, 80, 223)
+    m = AudioMiniEncoderWithClassifierHead(4, 80, 512)
+    print(m(x).shape)
+
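Together with the new register_model import, the @register_model hook above is what makes the classifier head constructible from a trainer options dict: register_mini_audio_encoder_classifier simply splats opt_net['kwargs'] into the constructor via opt_get. A small sketch of that lookup with a hypothetical options block; the opt_get below is a minimal reimplementation for illustration only, and every key or argument name other than 'kwargs' is an assumption rather than something taken from the repo.

def opt_get(opt, keys, default=None):
    # Minimal reimplementation of the nested-key lookup used by the registered constructor.
    for k in keys:
        if not isinstance(opt, dict) or k not in opt:
            return default
        opt = opt[k]
    return opt

# Hypothetical network options block; only the ['kwargs'] path is read here.
# The argument names inside 'kwargs' are assumptions about AudioMiniEncoder's signature.
opt_net = {'type': 'mini_audio_encoder_classifier',
           'kwargs': {'classes': 4, 'spec_dim': 80, 'embedding_dim': 512}}

kwargs = opt_get(opt_net, ['kwargs'], {})
print(kwargs)  # these would be splatted into AudioMiniEncoderWithClassifierHead(**kwargs)
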
@@ -123,7 +123,7 @@ class DiffusionVocoderWithRef(nn.Module):
             self.contextual_embedder = AudioMiniEncoder(conditioning_input_dim, time_embed_dim)
             self.query_gen = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
                                               attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
-            self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
+            self.embedding_combiner = EmbeddingCombiner(time_embed_dim, attn_blocks=1)
 
         self.input_blocks = nn.ModuleList(
             [
@@ -303,8 +303,8 @@ class DiffusionVocoderWithRef(nn.Module):
         emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
         if self.conditioning_enabled:
             emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
-            emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
-            emb = self.embedding_combiner(emb, None, self.query_gen(x))
+            emb2 = self.embedding_combiner(emb2, None, self.query_gen(x))
+            emb = emb1 + emb2
         else:
             emb = emb1
 
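This hunk is the heart of the rework. Before, the timestep embedding emb1 was concatenated onto the stack of per-clip conditioning embeddings and the whole set went through the EmbeddingCombiner; now the combiner only mixes the conditioning clips, attended by a query generated from the noisy input via query_gen, and its output is simply added to emb1. A shape-level sketch of the two formulations, using a dummy combiner under the assumption that EmbeddingCombiner maps a (batch, clips, channels) stack plus a (batch, channels) query down to (batch, channels).

import torch

B, N, C = 2, 3, 256          # batch, number of conditioning clips, embedding channels
emb1 = torch.randn(B, C)     # timestep embedding
emb2 = torch.randn(B, N, C)  # stacked per-clip conditioning embeddings
query = torch.randn(B, C)    # stand-in for self.query_gen(x)

def dummy_combiner(e, query):
    # Dummy EmbeddingCombiner: weight the N entries by similarity to the query.
    w = torch.softmax((e * query.unsqueeze(1)).sum(-1) / C ** 0.5, dim=1)  # (B, N)
    return (w.unsqueeze(-1) * e).sum(dim=1)                                # (B, C)

# Old: the timestep embedding itself went through the combiner.
emb_old = dummy_combiner(torch.cat([emb1.unsqueeze(1), emb2], dim=1), query)

# New: combine only the conditioning clips, then add to the timestep embedding.
emb_new = emb1 + dummy_combiner(emb2, query)

print(emb_old.shape, emb_new.shape)  # torch.Size([2, 256]) for both

One consequence of the new form: the timestep embedding always contributes additively to emb, regardless of the attention weights inside the combiner.
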
@@ -335,5 +335,5 @@ if __name__ == '__main__':
     spec = torch.randn(2,512,160)
     cond = torch.randn(2, 3, 80, 173)
     ts = torch.LongTensor([555, 556])
-    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
-    print(model(clip, ts, spec, cond, 4).shape)
+    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True)
+    print(model(clip, ts, spec, cond, 3).shape)