diff --git a/codes/models/gpt_voice/gpt_asr_hf.py b/codes/models/gpt_voice/gpt_asr_hf.py index 67e509f4..a120824a 100644 --- a/codes/models/gpt_voice/gpt_asr_hf.py +++ b/codes/models/gpt_voice/gpt_asr_hf.py @@ -145,15 +145,17 @@ class GPT2InferenceModel(GPT2PreTrainedModel): if input_ids.shape[1] != 1: text_inputs = input_ids[:, mel_len:] text_emb = self.transformer.get_input_embeddings()(text_inputs) - text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device)) + if self.text_pos_embedding is not None: + text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device)) if self.cached_mel_emb.shape[0] != text_emb.shape[0]: mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0) else: mel_emb = self.cached_mel_emb emb = torch.cat([mel_emb, text_emb], dim=1) else: - emb = self.transformer.get_input_embeddings()(input_ids) + \ - self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0) + emb = self.transformer.get_input_embeddings()(input_ids) + if self.text_pos_embedding is not None: + emb = emb + self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0) transformer_outputs = self.transformer( inputs_embeds=emb, diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py index 6c28dd25..31da0343 100644 --- a/codes/models/gpt_voice/gpt_tts_hf.py +++ b/codes/models/gpt_voice/gpt_tts_hf.py @@ -115,12 +115,18 @@ class GptTtsHf(nn.Module): def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8, repetition_penalty=1): if not hasattr(self, 'inference_model'): - self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head) + self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head) text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN) text_emb = self.text_embedding(text_inputs) + # Format conditioning inputs properly. + if len(cond_inputs.shape) == 3: + cond_inputs = cond_inputs.unsqueeze(1) # Format a single conditioning input as a set of {1} + if cond_inputs.shape[-1] > self.max_conditioning_length: + cond_inputs = cond_inputs[:,:,:,:self.max_conditioning_length] + conds = [] for k in range(cond_inputs.shape[1]): conds.append(self.conditioning_encoder(cond_inputs[:, k])) diff --git a/codes/scripts/audio/gen/use_gpt_tts.py b/codes/scripts/audio/gen/use_gpt_tts.py index 798cbbf0..50973a26 100644 --- a/codes/scripts/audio/gen/use_gpt_tts.py +++ b/codes/scripts/audio/gen/use_gpt_tts.py @@ -20,6 +20,7 @@ def do_vocoding(dvae, vocoder, diffuser, codes, cond=None, plot_spec=False): return +# Loads multiple conditioning files at random from a folder. def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length=44100): candidates = find_files_of_type('img', path, qualifier=is_audio_file)[0] # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates. @@ -37,6 +38,17 @@ def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length return torch.stack(related_mels, dim=0).unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda() +def load_conditioning(path, sample_rate=22050, cond_length=44100): + rel_clip = load_audio(path, sample_rate) + gap = rel_clip.shape[-1] - cond_length + if gap < 0: + rel_clip = F.pad(rel_clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + rel_clip = rel_clip[:, rand_start:rand_start + cond_length] + mel_clip = wav_to_mel(rel_clip.unsqueeze(0)).squeeze(0) + return mel_clip.unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda() + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -46,10 +58,9 @@ if __name__ == '__main__': parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae') parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml') parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt') - parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\48000_gpt.pth') - parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.") - parser.add_argument('-cond_path', type=str, help='Folder containing conditioning samples.', default='Y:\\clips\\podcasts-0\\8816_20210511-Pay Taxes Less Frequently_ We\'re Interested') - parser.add_argument('-num_cond', type=int, help='Number of conditioning samples to load.', default=3) + parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts_no_pos\\models\\28500_gpt_ema.pth') + parser.add_argument('-text', type=str, help='Text to speak.', default="Please set this in the courier drone when we dock.") + parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav') args = parser.parse_args() print("Loading GPT TTS..") @@ -60,7 +71,7 @@ if __name__ == '__main__': print("Loading data..") text = torch.IntTensor(text_to_sequence(args.text, ['english_cleaners'])).unsqueeze(0).cuda() - conds, cond_wav = load_conditioning_candidates(args.cond_path, args.num_cond) + conds, cond_wav = load_conditioning(args.cond_path) print("Performing GPT inference..") codes = gpt.inference(text, conds, num_beams=32, repetition_penalty=10.0)