diff --git a/api.py b/api.py index 2c4c336..b101518 100644 --- a/api.py +++ b/api.py @@ -76,7 +76,30 @@ def load_conditioning(clip, cond_length=132300): return mel_clip.unsqueeze(0).cuda() -def fix_autoregressive_output(codes, stop_token): +def clip_guided_generation(autoregressive_model, clip_model, conditioning_input, text_input, num_batches, stop_mel_token, + tokens_per_clip_inference=10, clip_results_to_reduce_to=8, **generation_kwargs): + """ + Uses a CLVP model trained to associate full text with **partial** audio clips to pick the best generation candidates + every few iterations. The top results are then propagated forward through the generation process. Rinse and repeat. + This is a hybrid between beam search and sampling. + """ + token_goal = tokens_per_clip_inference + finished = False + while not finished and token_goal < autoregressive_model.max_mel_tokens: + samples = [] + for b in tqdm(range(num_batches)): + codes = autoregressive_model.inference_speech(conditioning_input, text_input, **generation_kwargs) + samples.append(codes) + for batch in samples: + for i in range(batch.shape[0]): + batch[i] = fix_autoregressive_output(batch[i], stop_mel_token, complain=False) + clip_results.append(clip_model(text_input.repeat(batch.shape[0], 1), batch, return_loss=False)) + clip_results = torch.cat(clip_results, dim=0) + samples = torch.cat(samples, dim=0) + best_results = samples[torch.topk(clip_results, k=clip_results_to_reduce_to).indices] + + +def fix_autoregressive_output(codes, stop_token, complain=True): """ This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was trained on and what the autoregressive code generator creates (which has no padding or end). @@ -89,7 +112,8 @@ def fix_autoregressive_output(codes, stop_token): # Strip off the autoregressive stop token and add padding. stop_token_indices = (codes == stop_token).nonzero() if len(stop_token_indices) == 0: - print("No stop tokens found, enjoy that output of yours!") + if complain: + print("No stop tokens found, enjoy that output of yours!") return codes else: codes[stop_token_indices] = 83 @@ -136,14 +160,14 @@ class TextToSpeech: heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False, average_conditioning_embeddings=True).cpu().eval() - self.autoregressive.load_state_dict(torch.load('.models/autoregressive_diverse.pth')) + self.autoregressive.load_state_dict(torch.load('.models/autoregressive_audiobooks.pth')) self.autoregressive_for_latents = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, model_dim=1024, heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False, average_conditioning_embeddings=True).cpu().eval() - self.autoregressive_for_latents.load_state_dict(torch.load('.models/autoregressive_diverse.pth')) + self.autoregressive_for_latents.load_state_dict(torch.load('.models/autoregressive_audiobooks.pth')) self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, text_seq_len=350, text_heads=8, @@ -154,7 +178,7 @@ class TextToSpeech: self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, layer_drop=0, unconditioned_percentage=0).cpu().eval() - self.diffusion.load_state_dict(torch.load('.models/diffusion.pth')) + self.diffusion.load_state_dict(torch.load('.models/diffusion_decoder_audiobooks.pth')) self.vocoder = UnivNetGenerator().cpu() self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g']) @@ -170,7 +194,7 @@ class TextToSpeech: presets = { 'intelligible': {'temperature': .5, 'length_penalty': 2.0, 'repetition_penalty': 2.0, 'top_p': .5, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': .7, 'diffusion_temperature': .7}, 'mid': {'temperature': .7, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .7, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 1.5, 'diffusion_temperature': .8}, - 'realistic': {'temperature': .9, 'length_penalty': 1.0, 'repetition_penalty': 1.3, 'top_p': .9, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 2, 'diffusion_temperature': 1}, + 'realistic': {'temperature': 1.0, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .9, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 2, 'diffusion_temperature': 1}, } kwargs.update(presets[preset]) return self.tts(text, voice_samples, **kwargs) diff --git a/eval_multiple.py b/eval_multiple.py index 529d3bd..1113433 100644 --- a/eval_multiple.py +++ b/eval_multiple.py @@ -8,7 +8,7 @@ from utils.audio import load_audio if __name__ == '__main__': fname = 'Y:\\clips\\books2\\subset512-oco.tsv' stop_after = 128 - outpath_base = 'D:\\tmp\\tortoise-tts-eval\\diverse' + outpath_base = 'D:\\tmp\\tortoise-tts-eval\\audiobooks' outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real' os.makedirs(outpath_real, exist_ok=True) diff --git a/models/autoregressive.py b/models/autoregressive.py index 64fd451..8d5e462 100644 --- a/models/autoregressive.py +++ b/models/autoregressive.py @@ -511,7 +511,8 @@ class UnifiedVoice(nn.Module): loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_mel.mean() - def inference_speech(self, speech_conditioning_input, text_inputs, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): + def inference_speech(self, speech_conditioning_input, text_inputs, input_tokens=None, num_return_sequences=1, + max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): seq_length = self.max_mel_tokens + self.max_text_tokens + 2 if not hasattr(self, 'inference_model'): # TODO: Decouple gpt_config from this inference model. @@ -541,13 +542,23 @@ class UnifiedVoice(nn.Module): emb = torch.cat([conds, text_emb], dim=1) self.inference_model.store_mel_emb(emb) - fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device) - fake_inputs[:,-1] = self.start_mel_token + fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long, + device=text_inputs.device) + fake_inputs[:, -1] = self.start_mel_token + trunc_index = fake_inputs.shape[1] + if input_tokens is None: + inputs = fake_inputs + else: + assert num_return_sequences % input_tokens.shape[0] == 0, "The number of return sequences must be divisible by the number of input sequences" + fake_inputs = fake_inputs.repeat(num_return_sequences, 1) + input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1) + inputs = torch.cat([fake_inputs, input_tokens], dim=1) logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() - gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, - max_length=fake_inputs.shape[-1] + self.max_mel_tokens - 1, logits_processor=logits_processor, **hf_generate_kwargs) - return gen[:, fake_inputs.shape[1]:] + max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length + gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, + max_length=max_length, logits_processor=logits_processor, **hf_generate_kwargs) + return gen[:, trunc_index:] if __name__ == '__main__': diff --git a/read.py b/read.py index 639dd7d..dfcfc1d 100644 --- a/read.py +++ b/read.py @@ -32,15 +32,16 @@ if __name__ == '__main__': preselected_cond_voices = { 'emma_stone': ['voices/emma_stone/1.wav','voices/emma_stone/2.wav','voices/emma_stone/3.wav'], 'tom_hanks': ['voices/tom_hanks/1.wav','voices/tom_hanks/2.wav','voices/tom_hanks/3.wav'], + 'patrick_stewart': ['voices/patrick_stewart/1.wav','voices/patrick_stewart/2.wav','voices/patrick_stewart/3.wav','voices/patrick_stewart/4.wav'], } parser = argparse.ArgumentParser() parser.add_argument('-textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt") - parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='emma_stone') - parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512) + parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='patrick_stewart') + parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=128) parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16) parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/longform/') - parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='intelligible') + parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='realistic') args = parser.parse_args() os.makedirs(args.output_path, exist_ok=True) diff --git a/sweep.py b/sweep.py index 13b40fc..bc72fec 100644 --- a/sweep.py +++ b/sweep.py @@ -25,16 +25,15 @@ def permutations(args): if __name__ == '__main__': fname = 'Y:\\clips\\books2\\subset512-oco.tsv' - stop_after = 128 - outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep' + stop_after = 512 + outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep-2' outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real' arg_ranges = { - 'top_p': [.5, 1], - 'temperature': [.5, 1], - 'diffusion_temperature': [.6, 1], - 'cond_free_k': [0, 1, 4], - 'repetition_penalty': [1.0, 2.0] + 'top_p': [.8,1], + 'temperature': [.8,.9,1], + 'diffusion_temperature': [.8,1], + 'cond_free_k': [1,2,5,10], } cfgs = permutations(arg_ranges) shuffle(cfgs) @@ -56,8 +55,8 @@ if __name__ == '__main__': path = os.path.join(os.path.dirname(fname), line[1]) cond_audio = load_audio(path, 22050) torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050) - sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256, - k=1, diffusion_iterations=70, length_penalty=1.0, **cfg) + sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=32, repetition_penalty=2.0, + k=1, diffusion_iterations=32, length_penalty=1.0, **cfg) down = torchaudio.functional.resample(sample, 24000, 22050) fout_path = os.path.join(outpath, os.path.basename(line[1])) torchaudio.save(fout_path, down.squeeze(0), 22050)