get rid of autocasting in tts7
This commit is contained in:
parent
f458f5d8f1
commit
34ee32a90e
|
@ -431,7 +431,6 @@ class DiffusionTts(nn.Module):
|
||||||
unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
|
unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
|
||||||
unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
|
unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
|
||||||
|
|
||||||
with autocast(x.device.type):
|
|
||||||
orig_x_shape = x.shape[-1]
|
orig_x_shape = x.shape[-1]
|
||||||
cm = ceil_multiple(x.shape[-1], 2048)
|
cm = ceil_multiple(x.shape[-1], 2048)
|
||||||
if cm != 0:
|
if cm != 0:
|
||||||
|
@ -461,7 +460,6 @@ class DiffusionTts(nn.Module):
|
||||||
else:
|
else:
|
||||||
code_emb = self.conditioning_encoder(code_emb)
|
code_emb = self.conditioning_encoder(code_emb)
|
||||||
|
|
||||||
first = True
|
|
||||||
time_emb = time_emb.float()
|
time_emb = time_emb.float()
|
||||||
h = x
|
h = x
|
||||||
for k, module in enumerate(self.input_blocks):
|
for k, module in enumerate(self.input_blocks):
|
||||||
|
@ -469,18 +467,15 @@ class DiffusionTts(nn.Module):
|
||||||
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
|
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
|
||||||
h = h + h_tok
|
h = h + h_tok
|
||||||
else:
|
else:
|
||||||
with autocast(x.device.type, enabled=not first):
|
|
||||||
# First block has autocast disabled to allow a high precision signal to be properly vectorized.
|
# First block has autocast disabled to allow a high precision signal to be properly vectorized.
|
||||||
h = module(h, time_emb)
|
h = module(h, time_emb)
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
first = False
|
|
||||||
h = self.middle_block(h, time_emb)
|
h = self.middle_block(h, time_emb)
|
||||||
for module in self.output_blocks:
|
for module in self.output_blocks:
|
||||||
h = torch.cat([h, hs.pop()], dim=1)
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
h = module(h, time_emb)
|
h = module(h, time_emb)
|
||||||
|
|
||||||
# Last block also has autocast disabled for high-precision outputs.
|
# Last block also has autocast disabled for high-precision outputs.
|
||||||
h = h.float()
|
|
||||||
out = self.out(h)
|
out = self.out(h)
|
||||||
return out[:, :, :orig_x_shape]
|
return out[:, :, :orig_x_shape]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user