update unified
This commit is contained in:
parent
10fd1110be
commit
61cd351b71
|
@ -119,8 +119,8 @@ class UnifiedGptVoice(nn.Module):
|
|||
self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
|
||||
seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
|
||||
self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
|
||||
n_positions=seq_length-100, # -100 is a hack for backwards compatibility. TODO: remove at some point.
|
||||
n_ctx=seq_length-100,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
|
@ -151,6 +151,13 @@ class UnifiedGptVoice(nn.Module):
|
|||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True):
|
||||
# Remove the attention biases. I don't know why these are called biases because they are really just fixed attention masks forced into nn.Parameters, which are
|
||||
# easily regenerated and do not need to be saved. This is a hack to allow length modifications and should be removed in the future.
|
||||
filtered = dict(filter(lambda i: not i[0].endswith('.attn.bias'), state_dict.items()))
|
||||
assert len(filtered) == len(state_dict) - len(self.gpt.h)
|
||||
return super().load_state_dict(filtered, strict)
|
||||
|
||||
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||
inp = F.pad(input, (1,0), value=start_token)
|
||||
tar = F.pad(input, (0,1), value=stop_token)
|
||||
|
@ -298,7 +305,7 @@ class UnifiedGptVoice(nn.Module):
|
|||
emb = torch.cat([cond, text_emb], dim=1)
|
||||
self.inference_model.store_mel_emb(emb)
|
||||
|
||||
fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||
fake_inputs[:,-1] = self.start_mel_token
|
||||
|
||||
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
||||
|
|
|
@ -88,7 +88,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
|
||||
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts.yml')
|
||||
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
|
||||
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts_libri_all_and_hifi_no_unsupervised\\models\\4000_gpt.pth')
|
||||
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts_libri_all_and_hifi_no_unsupervised\\models\\17500_gpt.pth')
|
||||
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
||||
parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='')
|
||||
parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='libri_test')
|
||||
|
@ -100,7 +100,7 @@ if __name__ == '__main__':
|
|||
with open(args.opt_gpt_tts, mode='r') as f:
|
||||
gpt_opt = yaml.load(f, Loader=Loader)
|
||||
gpt_opt['networks'][args.gpt_tts_model_name]['kwargs']['checkpointing'] = False # Required for beam search
|
||||
gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path)
|
||||
gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path, strict_load=False)
|
||||
|
||||
print("Loading data..")
|
||||
tokenizer = CharacterTokenizer()
|
||||
|
|
|
@ -467,7 +467,7 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor
|
|||
|
||||
|
||||
Loader, Dumper = OrderedYaml()
|
||||
def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load_savepoint=True, load_path=None, preloaded_options=None):
|
||||
def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load_savepoint=True, load_path=None, preloaded_options=None, strict_load=True):
|
||||
if preloaded_options is not None:
|
||||
opt = preloaded_options
|
||||
else:
|
||||
|
@ -486,5 +486,5 @@ def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load
|
|||
load_path = opt['path'][f'pretrain_model_{model_name}']
|
||||
if load_path is not None:
|
||||
print(f"Loading from {load_path}")
|
||||
model.load_state_dict(torch.load(load_path))
|
||||
model.load_state_dict(torch.load(load_path), strict=strict_load)
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue
Block a user