From 61cd351b715fe3a32ffbd9ab13ed0e4357d27ca4 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Thu, 6 Jan 2022 09:48:11 -0700
Subject: [PATCH] update unified

---
 codes/models/gpt_voice/unified_voice.py | 13 ++++++++++---
 codes/scripts/audio/gen/use_gpt_tts.py  |  4 ++--
 codes/utils/util.py                     |  4 ++--
 3 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py
index b5096079..84809bdc 100644
--- a/codes/models/gpt_voice/unified_voice.py
+++ b/codes/models/gpt_voice/unified_voice.py
@@ -119,8 +119,8 @@ class UnifiedGptVoice(nn.Module):
         self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
         seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
         self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
-                                     n_positions=seq_length-100, # -100 is a hack for backwards compatibility. TODO: remove at some point.
-                                     n_ctx=seq_length-100,
+                                     n_positions=seq_length,
+                                     n_ctx=seq_length,
                                      n_embd=model_dim,
                                      n_layer=layers,
                                      n_head=heads,
@@ -151,6 +151,13 @@ class UnifiedGptVoice(nn.Module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True):
+        # Remove the attention biases. I don't know why these are called biases because they are really just fixed attention masks forced into nn.Parameters, which are
+        # easily regenerated and do not need to be saved. This is a hack to allow length modifications and should be removed in the future.
+        filtered = dict(filter(lambda i: not i[0].endswith('.attn.bias'), state_dict.items()))
+        assert len(filtered) == len(state_dict) - len(self.gpt.h)
+        return super().load_state_dict(filtered, strict)
+
     def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
         inp = F.pad(input, (1,0), value=start_token)
         tar = F.pad(input, (0,1), value=stop_token)
@@ -298,7 +305,7 @@ class UnifiedGptVoice(nn.Module):
         emb = torch.cat([cond, text_emb], dim=1)
         self.inference_model.store_mel_emb(emb)
 
-        fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
+        fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
         fake_inputs[:,-1] = self.start_mel_token
 
         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
diff --git a/codes/scripts/audio/gen/use_gpt_tts.py b/codes/scripts/audio/gen/use_gpt_tts.py
index 6c03a184..58eaa63b 100644
--- a/codes/scripts/audio/gen/use_gpt_tts.py
+++ b/codes/scripts/audio/gen/use_gpt_tts.py
@@ -88,7 +88,7 @@ if __name__ == '__main__':
     parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
     parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts.yml')
     parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
-    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts_libri_all_and_hifi_no_unsupervised\\models\\4000_gpt.pth')
+    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts_libri_all_and_hifi_no_unsupervised\\models\\17500_gpt.pth')
     parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
     parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='')
     parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='libri_test')
@@ -100,7 +100,7 @@ if __name__ == '__main__':
     with open(args.opt_gpt_tts, mode='r') as f:
         gpt_opt = yaml.load(f, Loader=Loader)
     gpt_opt['networks'][args.gpt_tts_model_name]['kwargs']['checkpointing'] = False  # Required for beam search
-    gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path)
+    gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path, strict_load=False)
 
     print("Loading data..")
     tokenizer = CharacterTokenizer()
diff --git a/codes/utils/util.py b/codes/utils/util.py
index d43faa8f..75b6cf09 100644
--- a/codes/utils/util.py
+++ b/codes/utils/util.py
@@ -467,7 +467,7 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor
 
 
 Loader, Dumper = OrderedYaml()
-def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load_savepoint=True, load_path=None, preloaded_options=None):
+def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load_savepoint=True, load_path=None, preloaded_options=None, strict_load=True):
     if preloaded_options is not None:
         opt = preloaded_options
     else:
@@ -486,5 +486,5 @@ def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load
         load_path = opt['path'][f'pretrain_model_{model_name}']
     if load_path is not None:
         print(f"Loading from {load_path}")
-        model.load_state_dict(torch.load(load_path))
+        model.load_state_dict(torch.load(load_path), strict=strict_load)
     return model