remove timesteps from cond calculation

This commit is contained in:
James Betker 2022-02-21 12:32:21 -07:00
parent 668876799d
commit 38802a96c8

View File

@ -432,7 +432,7 @@ class DiffusionTts(nn.Module):
tokens = torch.where(token_mask, self.mask_token_id, tokens) tokens = torch.where(token_mask, self.mask_token_id, tokens)
code_emb = self.code_embedding(tokens).permute(0,2,1) code_emb = self.code_embedding(tokens).permute(0,2,1)
cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1]) cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
cond_time_emb = timestep_embedding(timesteps, code_emb.shape[1]) cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1]) cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1)) code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
else: else: