update unified

This commit is contained in:
James Betker 2022-01-06 09:48:11 -07:00
parent 10fd1110be
commit 61cd351b71
3 changed files with 14 additions and 7 deletions

View File

@ -119,8 +119,8 @@ class UnifiedGptVoice(nn.Module):
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
n_positions=seq_length-100, # -100 is a hack for backwards compatibility. TODO: remove at some point.
n_ctx=seq_length-100,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
@ -151,6 +151,13 @@ class UnifiedGptVoice(nn.Module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True):
# Remove the attention biases. I don't know why these are called biases because they are really just fixed attention masks forced into nn.Parameters, which are
# easily regenerated and do not need to be saved. This is a hack to allow length modifications and should be removed in the future.
filtered = dict(filter(lambda i: not i[0].endswith('.attn.bias'), state_dict.items()))
assert len(filtered) == len(state_dict) - len(self.gpt.h)
return super().load_state_dict(filtered, strict)
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
@ -298,7 +305,7 @@ class UnifiedGptVoice(nn.Module):
emb = torch.cat([cond, text_emb], dim=1)
self.inference_model.store_mel_emb(emb)
fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.start_mel_token
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,

View File

@ -88,7 +88,7 @@ if __name__ == '__main__':
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts.yml')
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts_libri_all_and_hifi_no_unsupervised\\models\\4000_gpt.pth')
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts_libri_all_and_hifi_no_unsupervised\\models\\17500_gpt.pth')
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='')
parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='libri_test')
@ -100,7 +100,7 @@ if __name__ == '__main__':
with open(args.opt_gpt_tts, mode='r') as f:
gpt_opt = yaml.load(f, Loader=Loader)
gpt_opt['networks'][args.gpt_tts_model_name]['kwargs']['checkpointing'] = False # Required for beam search
gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path)
gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path, strict_load=False)
print("Loading data..")
tokenizer = CharacterTokenizer()

View File

@ -467,7 +467,7 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor
Loader, Dumper = OrderedYaml()
def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load_savepoint=True, load_path=None, preloaded_options=None):
def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load_savepoint=True, load_path=None, preloaded_options=None, strict_load=True):
if preloaded_options is not None:
opt = preloaded_options
else:
@ -486,5 +486,5 @@ def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load
load_path = opt['path'][f'pretrain_model_{model_name}']
if load_path is not None:
print(f"Loading from {load_path}")
model.load_state_dict(torch.load(load_path))
model.load_state_dict(torch.load(load_path), strict=strict_load)
return model