get rid of autocasting in tts7

This commit is contained in:
James Betker 2022-02-24 21:53:51 -07:00
parent f458f5d8f1
commit 34ee32a90e

View File

@ -431,7 +431,6 @@ class DiffusionTts(nn.Module):
unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1) unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1) unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
with autocast(x.device.type):
orig_x_shape = x.shape[-1] orig_x_shape = x.shape[-1]
cm = ceil_multiple(x.shape[-1], 2048) cm = ceil_multiple(x.shape[-1], 2048)
if cm != 0: if cm != 0:
@ -461,7 +460,6 @@ class DiffusionTts(nn.Module):
else: else:
code_emb = self.conditioning_encoder(code_emb) code_emb = self.conditioning_encoder(code_emb)
first = True
time_emb = time_emb.float() time_emb = time_emb.float()
h = x h = x
for k, module in enumerate(self.input_blocks): for k, module in enumerate(self.input_blocks):
@ -469,18 +467,15 @@ class DiffusionTts(nn.Module):
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
h = h + h_tok h = h + h_tok
else: else:
with autocast(x.device.type, enabled=not first):
# First block has autocast disabled to allow a high precision signal to be properly vectorized. # First block has autocast disabled to allow a high precision signal to be properly vectorized.
h = module(h, time_emb) h = module(h, time_emb)
hs.append(h) hs.append(h)
first = False
h = self.middle_block(h, time_emb) h = self.middle_block(h, time_emb)
for module in self.output_blocks: for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1) h = torch.cat([h, hs.pop()], dim=1)
h = module(h, time_emb) h = module(h, time_emb)
# Last block also has autocast disabled for high-precision outputs. # Last block also has autocast disabled for high-precision outputs.
h = h.float()
out = self.out(h) out = self.out(h)
return out[:, :, :orig_x_shape] return out[:, :, :orig_x_shape]