From 38802a96c83a38aec1221c7357271384366961af Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 21 Feb 2022 12:32:21 -0700 Subject: [PATCH] remove timesteps from cond calculation --- codes/models/gpt_voice/unet_diffusion_tts7.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codes/models/gpt_voice/unet_diffusion_tts7.py b/codes/models/gpt_voice/unet_diffusion_tts7.py index aad53597..77b76a36 100644 --- a/codes/models/gpt_voice/unet_diffusion_tts7.py +++ b/codes/models/gpt_voice/unet_diffusion_tts7.py @@ -432,7 +432,7 @@ class DiffusionTts(nn.Module): tokens = torch.where(token_mask, self.mask_token_id, tokens) code_emb = self.code_embedding(tokens).permute(0,2,1) cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1]) - cond_time_emb = timestep_embedding(timesteps, code_emb.shape[1]) + cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove. cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1]) code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1)) else: