From 34ee32a90ed868848e27cf2d69302ae48afd3739 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 24 Feb 2022 21:53:51 -0700
Subject: [PATCH] get rid of autocasting in tts7

---
 codes/models/gpt_voice/unet_diffusion_tts7.py | 85 +++++++++----------
 1 file changed, 40 insertions(+), 45 deletions(-)

diff --git a/codes/models/gpt_voice/unet_diffusion_tts7.py b/codes/models/gpt_voice/unet_diffusion_tts7.py
index 90bc12af..32b91d3f 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts7.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts7.py
@@ -431,56 +431,51 @@ class DiffusionTts(nn.Module):
             unaligned_h = self.unaligned_embedder(unaligned_input).permute(0,2,1)
             unaligned_h = self.unaligned_encoder(unaligned_h).permute(0,2,1)
 
-        with autocast(x.device.type):
-            orig_x_shape = x.shape[-1]
-            cm = ceil_multiple(x.shape[-1], 2048)
-            if cm != 0:
-                pc = (cm-x.shape[-1])/x.shape[-1]
-                x = F.pad(x, (0,cm-x.shape[-1]))
-                if tokens is not None:
-                    tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
-
-            hs = []
-            time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-
-            cond_emb = self.contextual_embedder(conditioning_input)
+        orig_x_shape = x.shape[-1]
+        cm = ceil_multiple(x.shape[-1], 2048)
+        if cm != 0:
+            pc = (cm-x.shape[-1])/x.shape[-1]
+            x = F.pad(x, (0,cm-x.shape[-1]))
             if tokens is not None:
-                # Mask out guidance tokens for un-guided diffusion.
-                if self.training and self.nil_guidance_fwd_proportion > 0:
-                    token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
-                    tokens = torch.where(token_mask, self.mask_token_id, tokens)
-                code_emb = self.code_embedding(tokens).permute(0,2,1)
-                cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
-                cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
-                cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
-                code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
-            else:
-                code_emb = cond_emb.unsqueeze(-1)
-            if self.enable_unaligned_inputs:
-                code_emb = self.conditioning_encoder(code_emb, context=unaligned_h)
-            else:
-                code_emb = self.conditioning_encoder(code_emb)
+                tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
 
-            first = True
-            time_emb = time_emb.float()
-            h = x
-            for k, module in enumerate(self.input_blocks):
-                if isinstance(module, nn.Conv1d):
-                    h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
-                    h = h + h_tok
-                else:
-                    with autocast(x.device.type, enabled=not first):
-                        # First block has autocast disabled to allow a high precision signal to be properly vectorized.
-                        h = module(h, time_emb)
-                hs.append(h)
-                first = False
-            h = self.middle_block(h, time_emb)
-            for module in self.output_blocks:
-                h = torch.cat([h, hs.pop()], dim=1)
+        hs = []
+        time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+        cond_emb = self.contextual_embedder(conditioning_input)
+        if tokens is not None:
+            # Mask out guidance tokens for un-guided diffusion.
+            if self.training and self.nil_guidance_fwd_proportion > 0:
+                token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
+                tokens = torch.where(token_mask, self.mask_token_id, tokens)
+            code_emb = self.code_embedding(tokens).permute(0,2,1)
+            cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
+            cond_time_emb = timestep_embedding(torch.zeros_like(timesteps), code_emb.shape[1]) # This was something I was doing (adding timesteps into this computation), but removed on second thought. TODO: completely remove.
+            cond_time_emb = cond_time_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
+            code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb, cond_time_emb], dim=1))
+        else:
+            code_emb = cond_emb.unsqueeze(-1)
+        if self.enable_unaligned_inputs:
+            code_emb = self.conditioning_encoder(code_emb, context=unaligned_h)
+        else:
+            code_emb = self.conditioning_encoder(code_emb)
+
+        time_emb = time_emb.float()
+        h = x
+        for k, module in enumerate(self.input_blocks):
+            if isinstance(module, nn.Conv1d):
+                h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
+                h = h + h_tok
+            else:
+                # First block has autocast disabled to allow a high precision signal to be properly vectorized.
                 h = module(h, time_emb)
+            hs.append(h)
+        h = self.middle_block(h, time_emb)
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = module(h, time_emb)
 
         # Last block also has autocast disabled for high-precision outputs.
-        h = h.float()
         out = self.out(h)
         return out[:, :, :orig_x_shape]
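For context, the `autocast` blocks deleted above are PyTorch's automatic mixed precision (AMP) context managers: eligible ops inside them may run in float16, which is what this commit trades away to keep the whole forward pass in float32. The second deleted form, `with autocast(..., enabled=not first)`, is the per-block variant: `enabled=False` turns the region into a no-op, which is how the first input block was already kept at full precision before this change. A minimal sketch of the before/after behavior, using a stand-in nn.Conv1d rather than anything from unet_diffusion_tts7.py (shapes and layer are illustrative only):

    import torch

    # Stand-in layer; hypothetical, not taken from this repo.
    model = torch.nn.Conv1d(16, 16, kernel_size=3, padding=1).cuda()
    x = torch.randn(1, 16, 2048, device='cuda')

    # Pattern this patch removes: inside autocast, conv ops run in fp16.
    with torch.autocast('cuda'):
        y_amp = model(x)
    print(y_amp.dtype)   # torch.float16

    # Disabled variant, equivalent to the old first-block special case:
    with torch.autocast('cuda', enabled=False):
        y_first = model(x)
    print(y_first.dtype)  # torch.float32

    # Pattern after this patch: no autocast, everything stays fp32.
    y_full = model(x)
    print(y_full.dtype)  # torch.float32

With the region removed, the trailing `h = h.float()` cast becomes dead code (h is already float32), which is why the hunk deletes it as well.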