From 935a4e853eb4633410132601fa3b37e5a7c1d2dc Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 27 Jan 2022 22:45:57 -0700 Subject: [PATCH] get rid of nil tokens in <2> --- codes/models/gpt_voice/unet_diffusion_tts2.py | 27 +------------------ 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/codes/models/gpt_voice/unet_diffusion_tts2.py b/codes/models/gpt_voice/unet_diffusion_tts2.py index de423e4d..513afd11 100644 --- a/codes/models/gpt_voice/unet_diffusion_tts2.py +++ b/codes/models/gpt_voice/unet_diffusion_tts2.py @@ -135,7 +135,6 @@ class DiffusionTts(nn.Module): scale_factor=2, conditioning_inputs_provided=True, time_embed_dim_multiplier=4, - nil_guidance_fwd_proportion=.3, ): super().__init__() @@ -154,8 +153,6 @@ class DiffusionTts(nn.Module): self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample self.dims = dims - self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion - self.mask_token_id = num_tokens padding = 1 if kernel_size == 3 else 2 @@ -186,7 +183,7 @@ class DiffusionTts(nn.Module): for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)): if ds in token_conditioning_resolutions: - token_conditioning_block = nn.Embedding(num_tokens+1, ch) + token_conditioning_block = nn.Embedding(num_tokens, ch) token_conditioning_block.weight.data.normal_(mean=0.0, std=.02) self.input_blocks.append(token_conditioning_block) token_conditioning_blocks.append(token_conditioning_block) @@ -289,23 +286,6 @@ class DiffusionTts(nn.Module): zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), ) - def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', - strict: bool = True): - # Temporary hack to allow the addition of nil-guidance token embeddings to the existing guidance embeddings. - lsd = self.state_dict() - revised = 0 - for i, blk in enumerate(self.input_blocks): - if isinstance(blk, nn.Embedding): - key = f'input_blocks.{i}.weight' - if state_dict[key].shape[0] != lsd[key].shape[0]: - t = torch.randn_like(lsd[key]) * .02 - t[:state_dict[key].shape[0]] = state_dict[key] - state_dict[key] = t - revised += 1 - print(f"Loaded experimental unet_diffusion_net with {revised} modifications.") - return super().load_state_dict(state_dict, strict) - - def forward(self, x, timesteps, tokens, conditioning_input=None): """ @@ -333,11 +313,6 @@ class DiffusionTts(nn.Module): else: emb = emb1 - # Mask out guidance tokens for un-guided diffusion. - if self.training and self.nil_guidance_fwd_proportion > 0: - token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion - tokens = torch.where(token_mask, self.mask_token_id, tokens) - h = x.type(self.dtype) for k, module in enumerate(self.input_blocks): if isinstance(module, nn.Embedding):