From 0edc98f6c467a3658eb844804201dfd8c9c352f7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 16 Oct 2021 09:02:01 -0600 Subject: [PATCH] Throw out the idea of conditioning on discrete codes. Oh well :( --- codes/models/gpt_voice/lucidrains_dvae.py | 6 ++++-- .../unet_diffusion_vocoder_with_ref.py | 18 +++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 04ca00a6..48b19b2e 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -190,8 +190,10 @@ class DiscreteVAE(nn.Module): arrange = 'b (h w) d -> b d h w' kwargs = {'h': h, 'w': w} image_embeds = rearrange(image_embeds, arrange, **kwargs) - images = self.decoder(image_embeds) - return images + images = [image_embeds] + for layer in self.decoder: + images.append(layer(images[-1])) + return images[-1], images[-2] def infer(self, img): img = self.norm(img) diff --git a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py index 8e2909fa..140afb89 100644 --- a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py +++ b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py @@ -11,9 +11,9 @@ from utils.util import get_mask_from_lengths class DiscreteSpectrogramConditioningBlock(nn.Module): - def __init__(self, discrete_codes, channels): + def __init__(self, dvae_channels, channels): super().__init__() - self.emb = nn.Embedding(discrete_codes, channels) + self.emb = nn.Conv1d(dvae_channels, channels, kernel_size=1) self.norm = normalization(channels) self.act = nn.SiLU() self.intg = nn.Sequential(nn.Conv1d(channels*2, channels*2, kernel_size=1), @@ -30,11 +30,10 @@ class DiscreteSpectrogramConditioningBlock(nn.Module): :param x: bxcxS waveform latent :param codes: bxN discrete codes, N <= S """ - def forward(self, x, codes): - _, c, S = x.shape - b, N = codes.shape - assert N <= S - emb = self.emb(codes).permute(0,2,1) + def forward(self, x, dvae_in): + b, c, S = x.shape + _, q, N = dvae_in.shape + emb = self.emb(dvae_in) emb = nn.functional.interpolate(emb, size=(S,), mode='nearest') together = torch.cat([self.act(self.norm(x)), emb], dim=1) together = self.intg(together) @@ -77,7 +76,7 @@ class DiffusionVocoderWithRef(nn.Module): model_channels, in_channels=1, out_channels=2, # mean and variance - discrete_codes=8192, + discrete_codes=512, dropout=0, # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48), @@ -339,7 +338,8 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt): # Test for ~4 second audio clip at 22050Hz if __name__ == '__main__': clip = torch.randn(2, 1, 40960) - spec = torch.randint(8192, (2, 40,)) + #spec = torch.randint(8192, (2, 40,)) + spec = torch.randn(8,512,160) cond = torch.randn(2, 3, 80, 173) ts = torch.LongTensor([555, 556]) model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)