forked from mrq/DL-Art-School
remove timesteps from cond calculation
This commit is contained in:
parent
668876799d
commit
38802a96c8
|
@ -432,7 +432,7 @@ class DiffusionTts(nn.Module):
|
||||||
tokens = torch.where(token_mask, self.mask_token_id, tokens)
|
tokens = torch.where(token_mask, self.mask_token_id, tokens)
|
||||||
code_emb = self.code_embedding(tokens).permute(0,2,1)
|
code_emb = self.code_embedding(tokens).permute(0,2,1)
|
||||||
cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||||
cond_time_emb = timestep_embedding(timesteps, code_emb.shape[1])
|
cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
|
||||||
cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||||
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
|
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user