Re-instate autocasting
This commit is contained in:
parent
34ee32a90e
commit
c375287db9
|
@ -431,51 +431,56 @@ class DiffusionTts(nn.Module):
|
|||
unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
|
||||
unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
|
||||
|
||||
orig_x_shape = x.shape[-1]
|
||||
cm = ceil_multiple(x.shape[-1], 2048)
|
||||
if cm != 0:
|
||||
pc = (cm-x.shape[-1])/x.shape[-1]
|
||||
x = F.pad(x, (0,cm-x.shape[-1]))
|
||||
with autocast(x.device.type):
|
||||
orig_x_shape = x.shape[-1]
|
||||
cm = ceil_multiple(x.shape[-1], 2048)
|
||||
if cm != 0:
|
||||
pc = (cm-x.shape[-1])/x.shape[-1]
|
||||
x = F.pad(x, (0,cm-x.shape[-1]))
|
||||
if tokens is not None:
|
||||
tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
|
||||
|
||||
hs = []
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
cond_emb = self.contextual_embedder(conditioning_input)
|
||||
if tokens is not None:
|
||||
tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
|
||||
|
||||
hs = []
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
cond_emb = self.contextual_embedder(conditioning_input)
|
||||
if tokens is not None:
|
||||
# Mask out guidance tokens for un-guided diffusion.
|
||||
if self.training and self.nil_guidance_fwd_proportion > 0:
|
||||
token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
|
||||
tokens = torch.where(token_mask, self.mask_token_id, tokens)
|
||||
code_emb = self.code_embedding(tokens).permute(0,2,1)
|
||||
cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||
cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
|
||||
cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
|
||||
else:
|
||||
code_emb = cond_emb.unsqueeze(-1)
|
||||
if self.enable_unaligned_inputs:
|
||||
code_emb = self.conditioning_encoder(code_emb, context=unaligned_h)
|
||||
else:
|
||||
code_emb = self.conditioning_encoder(code_emb)
|
||||
|
||||
time_emb = time_emb.float()
|
||||
h = x
|
||||
for k, module in enumerate(self.input_blocks):
|
||||
if isinstance(module, nn.Conv1d):
|
||||
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
|
||||
h = h + h_tok
|
||||
# Mask out guidance tokens for un-guided diffusion.
|
||||
if self.training and self.nil_guidance_fwd_proportion > 0:
|
||||
token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
|
||||
tokens = torch.where(token_mask, self.mask_token_id, tokens)
|
||||
code_emb = self.code_embedding(tokens).permute(0,2,1)
|
||||
cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||
cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
|
||||
cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
|
||||
else:
|
||||
# First block has autocast disabled to allow a high precision signal to be properly vectorized.
|
||||
code_emb = cond_emb.unsqueeze(-1)
|
||||
if self.enable_unaligned_inputs:
|
||||
code_emb = self.conditioning_encoder(code_emb, context=unaligned_h)
|
||||
else:
|
||||
code_emb = self.conditioning_encoder(code_emb)
|
||||
|
||||
first = True
|
||||
time_emb = time_emb.float()
|
||||
h = x
|
||||
for k, module in enumerate(self.input_blocks):
|
||||
if isinstance(module, nn.Conv1d):
|
||||
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
|
||||
h = h + h_tok
|
||||
else:
|
||||
with autocast(x.device.type, enabled=not first):
|
||||
# First block has autocast disabled to allow a high precision signal to be properly vectorized.
|
||||
h = module(h, time_emb)
|
||||
hs.append(h)
|
||||
first = False
|
||||
h = self.middle_block(h, time_emb)
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, time_emb)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, time_emb)
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, time_emb)
|
||||
|
||||
# Last block also has autocast disabled for high-precision outputs.
|
||||
h = h.float()
|
||||
out = self.out(h)
|
||||
return out[:, :, :orig_x_shape]
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user