From 3a9d1c53eacf70cd4f0a5731ff936bbdd2b5775a Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 26 Oct 2021 10:46:33 -0600
Subject: [PATCH] Rework conditioning inputs provided

---
 .../gpt_voice/unet_diffusion_vocoder_with_ref.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
index afd3375c..55ccd34c 100644
--- a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
+++ b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
@@ -283,7 +283,7 @@ class DiffusionVocoderWithRef(nn.Module):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)
 
-    def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None):
+    def forward(self, x, timesteps, spectrogram, conditioning_input=None):
         """
         Apply the model to an input batch.
 
@@ -294,14 +294,12 @@ class DiffusionVocoderWithRef(nn.Module):
         """
         assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
         if self.conditioning_enabled:
-            assert conditioning_inputs is not None
-            assert num_conditioning_signals is not None
+            assert conditioning_input is not None
 
         hs = []
         emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
         if self.conditioning_enabled:
-            #emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
-            emb2 = self.contextual_embedder(conditioning_inputs[:, 0])
+            emb2 = self.contextual_embedder(conditioning_input)
             emb = emb1 + emb2
         else:
             emb = emb1
@@ -331,7 +329,7 @@ if __name__ == '__main__':
     clip = torch.randn(2, 1, 40960)
     #spec = torch.randint(8192, (2, 40,))
     spec = torch.randn(2,512,160)
-    cond = torch.randn(2, 3, 80, 173)
+    cond = torch.randn(2, 80, 173)
     ts = torch.LongTensor([555, 556])
     model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
-    print(model(clip, ts, spec, cond, 3).shape)
+    print(model(clip, ts, spec, cond).shape)
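
For reference, the reworked call site boils down to the patched __main__ smoke
test; the sketch below restates it as a standalone snippet. The import path is
an assumption inferred from the file's location under codes/ and is not stated
anywhere in the patch itself.

    import torch
    # Assumed import path, mirroring this patch's file location under codes/.
    from models.gpt_voice.unet_diffusion_vocoder_with_ref import DiffusionVocoderWithRef

    clip = torch.randn(2, 1, 40960)    # waveform batch; length must be a multiple of 4096
    spec = torch.randn(2, 512, 160)    # spectrogram conditioning
    cond = torch.randn(2, 80, 173)     # one reference clip per batch item (no per-signal axis)
    ts = torch.LongTensor([555, 556])  # diffusion timesteps

    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True,
                                    time_embed_dim_multiplier=8)
    out = model(clip, ts, spec, cond)  # num_conditioning_signals is no longer a parameter
    print(out.shape)

Note that the pre-patch code only ever embedded conditioning_inputs[:, 0] (the
torch.stack variant was already commented out), so collapsing to a single
conditioning_input drops an unused tensor axis rather than any functionality.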