diff --git a/codes/models/gpt_voice/mini_encoder.py b/codes/models/gpt_voice/mini_encoder.py
index 86e61d20..5d05cebb 100644
--- a/codes/models/gpt_voice/mini_encoder.py
+++ b/codes/models/gpt_voice/mini_encoder.py
@@ -6,7 +6,8 @@ from models.diffusion.nn import normalization, conv_nd, zero_module
 from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy, Upsample
 
 # Combined resnet & full-attention encoder for converting an audio clip into an embedding.
-from utils.util import checkpoint
+from trainer.networks import register_model
+from utils.util import checkpoint, opt_get
 
 
 class ResBlock(nn.Module):
@@ -111,15 +112,25 @@ class AudioMiniEncoder(nn.Module):
         for a in range(attn_blocks):
             attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
         self.attn = nn.Sequential(*attn)
+        self.dim = embedding_dim
 
     def forward(self, x):
         h = self.init(x)
         h = self.res(h)
         h = self.final(h)
-        h = self.attn(h)
+        h = checkpoint(self.attn, h)
         return h[:, :, 0]
 
 
+class AudioMiniEncoderWithClassifierHead(nn.Module):
+    def __init__(self, classes, **kwargs):
+        super().__init__()
+        self.enc = AudioMiniEncoder(**kwargs)
+        self.head = nn.Linear(self.enc.dim, classes)
+
+    def forward(self, x):
+        h = self.enc(x)
+        return self.head(h)
+
+
 class QueryProvidedAttentionBlock(nn.Module):
@@ -188,7 +199,13 @@ class EmbeddingCombiner(nn.Module):
         return y[:, 0]
 
 
+@register_model
+def register_mini_audio_encoder_classifier(opt_net, opt):
+    return AudioMiniEncoderWithClassifierHead(**opt_get(opt_net, ['kwargs'], {}))
+
+
 if __name__ == '__main__':
+    '''
     x = torch.randn(2, 80, 223)
     cond = torch.randn(2, 512)
     encs = [AudioMiniEncoder(80, 512) for _ in range(5)]
@@ -197,3 +214,7 @@ if __name__ == '__main__':
     e = torch.stack([e(x) for e in encs], dim=2)
 
     print(combiner(e, cond).shape)
+    '''
+    x = torch.randn(2, 80, 223)
+    m = AudioMiniEncoderWithClassifierHead(4, 80, 512)
+    print(m(x).shape)
diff --git a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
index dd7800a9..925a00fd 100644
--- a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
+++ b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
@@ -123,7 +123,7 @@ class DiffusionVocoderWithRef(nn.Module):
             self.contextual_embedder = AudioMiniEncoder(conditioning_input_dim, time_embed_dim)
             self.query_gen = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
                                               attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
-            self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
+            self.embedding_combiner = EmbeddingCombiner(time_embed_dim, attn_blocks=1)
 
         self.input_blocks = nn.ModuleList(
             [
@@ -303,8 +303,8 @@ class DiffusionVocoderWithRef(nn.Module):
         emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
         if self.conditioning_enabled:
             emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
-            emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
-            emb = self.embedding_combiner(emb, None, self.query_gen(x))
+            emb2 = self.embedding_combiner(emb2, None, self.query_gen(x))
+            emb = emb1 + emb2
         else:
             emb = emb1
 
@@ -335,5 +335,5 @@ if __name__ == '__main__':
     spec = torch.randn(2,512,160)
     cond = torch.randn(2, 3, 80, 173)
     ts = torch.LongTensor([555, 556])
-    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
-    print(model(clip, ts, spec, cond, 4).shape)
+    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True)
+    print(model(clip, ts, spec, cond, 3).shape)