forked from mrq/DL-Art-School
update unified
parent 10fd1110be
commit 61cd351b71
@@ -119,8 +119,8 @@ class UnifiedGptVoice(nn.Module):
         self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
         seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
         self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
-                                     n_positions=seq_length-100, # -100 is a hack for backwards compatibility. TODO: remove at some point.
-                                     n_ctx=seq_length-100,
+                                     n_positions=seq_length,
+                                     n_ctx=seq_length,
                                      n_embd=model_dim,
                                      n_layer=layers,
                                      n_head=heads,
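Dropping the two `-100` offsets means the GPT-2 positional budget now covers the full computed sequence length, rather than an artificially shrunken one kept for backwards compatibility; checkpoints trained with the old, smaller budget are presumably what the loading changes further down accommodate. A minimal sketch of the arithmetic, with hypothetical token limits (the real values come from the model kwargs in the training YAML):

from transformers import GPT2Config

# Hypothetical limits; the actual values are read from the opt YAML, not these.
max_text_tokens = 120
max_mel_tokens = 250
max_conditioning_inputs = 1

# The leading 2 presumably budgets for the start/stop tokens padded around the sequences.
seq_length = 2 + max_text_tokens + max_mel_tokens + max_conditioning_inputs

cfg = GPT2Config(n_positions=seq_length, n_embd=512, n_layer=8, n_head=8)
print(seq_length, cfg.n_positions)  # 373 373 -- no more seq_length-100 hack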
@@ -151,6 +151,13 @@ class UnifiedGptVoice(nn.Module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True):
+        # Remove the attention biases. I don't know why these are called biases because they are really just fixed attention masks forced into nn.Parameters, which are
+        # easily regenerated and do not need to be saved. This is a hack to allow length modifications and should be removed in the future.
+        filtered = dict(filter(lambda i: not i[0].endswith('.attn.bias'), state_dict.items()))
+        assert len(filtered) == len(state_dict) - len(self.gpt.h)
+        return super().load_state_dict(filtered, strict)
+
     def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
         inp = F.pad(input, (1,0), value=start_token)
         tar = F.pad(input, (0,1), value=stop_token)
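The new override strips the per-block `attn.bias` entries (one per layer in `self.gpt.h`) from an incoming checkpoint before handing it to the stock loader; since these are fixed causal masks, they can be rebuilt at construction time for whatever sequence length the model now uses. Because the model's own copies of those masks then go unmatched, the caller presumably also has to load non-strictly (see the `strict_load=False` change below). A self-contained sketch of the same filtering pattern on a made-up state dict (key names, shapes and block count are hypothetical, not the real module tree):

from collections import OrderedDict
import torch

# Hypothetical checkpoint with two transformer blocks; only the key names matter here.
state_dict = OrderedDict([
    ('gpt.h.0.attn.c_attn.weight', torch.zeros(4, 12)),
    ('gpt.h.0.attn.bias',          torch.ones(1, 1, 8, 8)),   # fixed causal mask, regenerable
    ('gpt.h.1.attn.c_attn.weight', torch.zeros(4, 12)),
    ('gpt.h.1.attn.bias',          torch.ones(1, 1, 8, 8)),
])

num_blocks = 2
filtered = dict(filter(lambda i: not i[0].endswith('.attn.bias'), state_dict.items()))

# Exactly one '.attn.bias' entry per block should have been dropped.
assert len(filtered) == len(state_dict) - num_blocks
print(sorted(filtered))  # only the learned weights remain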
@@ -298,7 +305,7 @@ class UnifiedGptVoice(nn.Module):
         emb = torch.cat([cond, text_emb], dim=1)
         self.inference_model.store_mel_emb(emb)
 
-        fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
+        fake_inputs = torch.full((emb.shape[0], emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
         fake_inputs[:,-1] = self.start_mel_token
 
         gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
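The only change in this hunk is cosmetic (a space after the comma), but the pattern is worth spelling out: generate() is seeded with a dummy token sequence whose length matches the stored conditioning+text embedding plus one slot holding the start-of-mel token, so the autoregressive loop continues from the right position while the real embeddings are presumably swapped in for the placeholders inside the inference model. A torch-only sketch with hypothetical sizes (the real `start_mel_token` id and embedding come from the model):

import torch

batch, cond_plus_text_len, model_dim = 2, 10, 512
start_mel_token = 8192            # hypothetical id
emb = torch.randn(batch, cond_plus_text_len, model_dim)   # stands in for torch.cat([cond, text_emb], dim=1)

# Placeholder ids: one per embedding position, plus a final slot for the start-of-mel token.
fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1), fill_value=1, dtype=torch.long)
fake_inputs[:, -1] = start_mel_token

print(fake_inputs.shape)          # torch.Size([2, 11])
print(fake_inputs[0, -3:])        # tensor([   1,    1, 8192])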
@@ -88,7 +88,7 @@ if __name__ == '__main__':
     parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
     parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts.yml')
     parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
-    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts_libri_all_and_hifi_no_unsupervised\\models\\4000_gpt.pth')
+    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts_libri_all_and_hifi_no_unsupervised\\models\\17500_gpt.pth')
     parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
     parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='')
     parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='libri_test')
@@ -100,7 +100,7 @@ if __name__ == '__main__':
     with open(args.opt_gpt_tts, mode='r') as f:
         gpt_opt = yaml.load(f, Loader=Loader)
     gpt_opt['networks'][args.gpt_tts_model_name]['kwargs']['checkpointing'] = False # Required for beam search
-    gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path)
+    gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path, strict_load=False)
 
     print("Loading data..")
     tokenizer = CharacterTokenizer()
@@ -467,7 +467,7 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor
 
 
 Loader, Dumper = OrderedYaml()
-def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load_savepoint=True, load_path=None, preloaded_options=None):
+def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load_savepoint=True, load_path=None, preloaded_options=None, strict_load=True):
     if preloaded_options is not None:
         opt = preloaded_options
     else:
@@ -486,5 +486,5 @@ def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load
             load_path = opt['path'][f'pretrain_model_{model_name}']
     if load_path is not None:
         print(f"Loading from {load_path}")
-        model.load_state_dict(torch.load(load_path))
+        model.load_state_dict(torch.load(load_path), strict=strict_load)
     return model
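The new `strict_load` flag is passed straight through to `nn.Module.load_state_dict`. With `strict=False`, PyTorch no longer raises on keys that are missing from or unexpected in the checkpoint; it returns them instead (tensor shape mismatches still raise). A minimal sketch of the difference on a toy module, unrelated to the DLAS code itself:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)

# Checkpoint that is missing 'bias' and carries an extra, unknown key.
ckpt = {'weight': torch.zeros(4, 4), 'some_regenerable_mask': torch.ones(1, 4, 4)}

result = model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['some_regenerable_mask']

# With strict=True (the default preserved by strict_load=True), the same call raises a RuntimeError.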