From d9f8f92840cd9b5082e0666812f5e4a01e0b85a6 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 1 Mar 2022 15:46:04 -0700 Subject: [PATCH] Codified fp16 --- codes/models/gpt_voice/unet_diffusion_tts7.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/codes/models/gpt_voice/unet_diffusion_tts7.py b/codes/models/gpt_voice/unet_diffusion_tts7.py index bf05e125..f99ce990 100644 --- a/codes/models/gpt_voice/unet_diffusion_tts7.py +++ b/codes/models/gpt_voice/unet_diffusion_tts7.py @@ -219,7 +219,6 @@ class DiffusionTts(nn.Module): self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample - self.dtype = torch.float16 if use_fp16 else torch.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample @@ -229,6 +228,7 @@ class DiffusionTts(nn.Module): self.super_sampling_enabled = super_sampling self.super_sampling_max_noising_factor = super_sampling_max_noising_factor self.unconditioned_percentage = unconditioned_percentage + self.enable_fp16 = use_fp16 padding = 1 if kernel_size == 3 else 2 time_embed_dim = model_channels * time_embed_dim_multiplier @@ -431,7 +431,7 @@ class DiffusionTts(nn.Module): lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest') x = torch.cat([x, lr_input], dim=1) - with autocast(x.device.type): + with autocast(x.device.type, enabled=self.enable_fp16): orig_x_shape = x.shape[-1] cm = ceil_multiple(x.shape[-1], 2048) if cm != 0: @@ -482,7 +482,7 @@ class DiffusionTts(nn.Module): h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') h = h + h_tok else: - with autocast(x.device.type, enabled=not first): + with autocast(x.device.type, enabled=self.enable_fp16 and not first): # First block has autocast disabled to allow a high precision signal to be properly vectorized. h = module(h, time_emb) hs.append(h)