From 65a546c4d7602da329345cf6245c03051a774159 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 5 Feb 2022 16:00:14 -0700 Subject: [PATCH] Fix for tts6 --- codes/models/gpt_voice/unet_diffusion_tts6.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/codes/models/gpt_voice/unet_diffusion_tts6.py b/codes/models/gpt_voice/unet_diffusion_tts6.py index 13a7ab7a..2ad52943 100644 --- a/codes/models/gpt_voice/unet_diffusion_tts6.py +++ b/codes/models/gpt_voice/unet_diffusion_tts6.py @@ -364,23 +364,6 @@ class DiffusionTts(nn.Module): zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), ) - def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', - strict: bool = True): - # Temporary hack to allow the addition of nil-guidance token embeddings to the existing guidance embeddings. - lsd = self.state_dict() - revised = 0 - for i, blk in enumerate(self.input_blocks): - if isinstance(blk, nn.Embedding): - key = f'input_blocks.{i}.weight' - if state_dict[key].shape[0] != lsd[key].shape[0]: - t = torch.randn_like(lsd[key]) * .02 - t[:state_dict[key].shape[0]] = state_dict[key] - state_dict[key] = t - revised += 1 - print(f"Loaded experimental unet_diffusion_net with {revised} modifications.") - return super().load_state_dict(state_dict, strict) - - def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None): """