From 38802a96c83a38aec1221c7357271384366961af Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Mon, 21 Feb 2022 12:32:21 -0700
Subject: [PATCH] remove timesteps from cond calculation

---
 codes/models/gpt_voice/unet_diffusion_tts7.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/codes/models/gpt_voice/unet_diffusion_tts7.py b/codes/models/gpt_voice/unet_diffusion_tts7.py
index aad53597..77b76a36 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts7.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts7.py
@@ -432,7 +432,7 @@ class DiffusionTts(nn.Module):
                     tokens = torch.where(token_mask, self.mask_token_id, tokens)
                 code_emb = self.code_embedding(tokens).permute(0,2,1)
                 cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
-                cond_time_emb = timestep_embedding(timesteps, code_emb.shape[1])
+                cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1])  # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
                 cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
                 code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
             else: