From 935a4e853eb4633410132601fa3b37e5a7c1d2dc Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Thu, 27 Jan 2022 22:45:57 -0700
Subject: [PATCH] get rid of nil tokens in <2>
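
Remove the nil-guidance (masked-token) training path from
unet_diffusion_tts2: the nil_guidance_fwd_proportion argument, the
extra mask-token row in the token conditioning embeddings
(num_tokens+1 -> num_tokens), the training-time token masking in
forward(), and the load_state_dict() hack that padded pretrained
guidance embeddings up to the larger vocabulary.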

---
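Note: for reference, a minimal sketch of the training-time guidance-token
dropout this patch removes. The standalone function name is hypothetical;
the 0.3 proportion, the mask id, and the torch calls mirror the deleted
lines.

    import torch

    def mask_guidance_tokens(tokens: torch.Tensor, mask_token_id: int,
                             proportion: float = 0.3) -> torch.Tensor:
        # Replace roughly `proportion` of the guidance tokens with the
        # nil/mask token so the model also trains on un-guided examples.
        drop = torch.rand(tokens.shape, device=tokens.device) < proportion
        return torch.where(drop, torch.full_like(tokens, mask_token_id), tokens)
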
 codes/models/gpt_voice/unet_diffusion_tts2.py | 27 +------------------
 1 file changed, 1 insertion(+), 26 deletions(-)

diff --git a/codes/models/gpt_voice/unet_diffusion_tts2.py b/codes/models/gpt_voice/unet_diffusion_tts2.py
index de423e4d..513afd11 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts2.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts2.py
@@ -135,7 +135,6 @@ class DiffusionTts(nn.Module):
             scale_factor=2,
             conditioning_inputs_provided=True,
             time_embed_dim_multiplier=4,
-            nil_guidance_fwd_proportion=.3,
     ):
         super().__init__()
 
@@ -154,8 +153,6 @@ class DiffusionTts(nn.Module):
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
         self.dims = dims
-        self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
-        self.mask_token_id = num_tokens
 
         padding = 1 if kernel_size == 3 else 2
 
@@ -186,7 +183,7 @@ class DiffusionTts(nn.Module):
 
         for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in token_conditioning_resolutions:
-                token_conditioning_block = nn.Embedding(num_tokens+1, ch)
+                token_conditioning_block = nn.Embedding(num_tokens, ch)
                 token_conditioning_block.weight.data.normal_(mean=0.0, std=.02)
                 self.input_blocks.append(token_conditioning_block)
                 token_conditioning_blocks.append(token_conditioning_block)
@@ -289,23 +286,6 @@ class DiffusionTts(nn.Module):
             zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )
 
-    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
-                        strict: bool = True):
-        # Temporary hack to allow the addition of nil-guidance token embeddings to the existing guidance embeddings.
-        lsd = self.state_dict()
-        revised = 0
-        for i, blk in enumerate(self.input_blocks):
-            if isinstance(blk, nn.Embedding):
-                key = f'input_blocks.{i}.weight'
-                if state_dict[key].shape[0] != lsd[key].shape[0]:
-                    t = torch.randn_like(lsd[key]) * .02
-                    t[:state_dict[key].shape[0]] = state_dict[key]
-                    state_dict[key] = t
-                    revised += 1
-        print(f"Loaded experimental unet_diffusion_net with {revised} modifications.")
-        return super().load_state_dict(state_dict, strict)
-
-
 
     def forward(self, x, timesteps, tokens, conditioning_input=None):
         """
@@ -333,11 +313,6 @@ class DiffusionTts(nn.Module):
         else:
             emb = emb1
 
-        # Mask out guidance tokens for un-guided diffusion.
-        if self.training and self.nil_guidance_fwd_proportion > 0:
-            token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
-            tokens = torch.where(token_mask, self.mask_token_id, tokens)
-
         h = x.type(self.dtype)
         for k, module in enumerate(self.input_blocks):
             if isinstance(module, nn.Embedding):