implement clip-guided generation (and never use it...)

James Betker 2022-04-14 21:50:57 -06:00
parent 60d363fc60
commit 979ff6e65e
5 changed files with 60 additions and 25 deletions

api.py
View File

@@ -76,7 +76,30 @@ def load_conditioning(clip, cond_length=132300):
    return mel_clip.unsqueeze(0).cuda()


-def fix_autoregressive_output(codes, stop_token):
+def clip_guided_generation(autoregressive_model, clip_model, conditioning_input, text_input, num_batches, stop_mel_token,
+                           tokens_per_clip_inference=10, clip_results_to_reduce_to=8, **generation_kwargs):
+    """
+    Uses a CLVP model trained to associate full text with **partial** audio clips to pick the best generation candidates
+    every few iterations. The top results are then propagated forward through the generation process. Rinse and repeat.
+    This is a hybrid between beam search and sampling.
+    """
+    token_goal = tokens_per_clip_inference
+    finished = False
+    while not finished and token_goal < autoregressive_model.max_mel_tokens:
+        samples = []
+        for b in tqdm(range(num_batches)):
+            codes = autoregressive_model.inference_speech(conditioning_input, text_input, **generation_kwargs)
+            samples.append(codes)
+        for batch in samples:
+            for i in range(batch.shape[0]):
+                batch[i] = fix_autoregressive_output(batch[i], stop_mel_token, complain=False)
+            clip_results.append(clip_model(text_input.repeat(batch.shape[0], 1), batch, return_loss=False))
+        clip_results = torch.cat(clip_results, dim=0)
+        samples = torch.cat(samples, dim=0)
+        best_results = samples[torch.topk(clip_results, k=clip_results_to_reduce_to).indices]
+
+
+def fix_autoregressive_output(codes, stop_token, complain=True):
    """
    This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
    trained on and what the autoregressive code generator creates (which has no padding or end).
@@ -89,6 +112,7 @@ def fix_autoregressive_output(codes, stop_token):
    # Strip off the autoregressive stop token and add padding.
    stop_token_indices = (codes == stop_token).nonzero()
    if len(stop_token_indices) == 0:
-        print("No stop tokens found, enjoy that output of yours!")
+        if complain:
+            print("No stop tokens found, enjoy that output of yours!")
        return codes
    else:
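The loop shown above selects best_results but, at least in this hunk, never feeds the winners back into the next round of sampling. A minimal sketch of how that propagation could work, using the input_tokens and max_generate_length arguments added to inference_speech later in this commit (the helper name and bookkeeping below are illustrative, not the committed code):

# Illustrative sketch only -- not part of this commit. Feeds the CLVP-selected candidates back
# into the autoregressive model so the next round of sampling continues from the best partial clips.
def propagate_best_candidates(autoregressive_model, best_results, conditioning_input, text_input,
                              token_goal, tokens_per_clip_inference, **generation_kwargs):
    # Keep only the first token_goal mel tokens so every surviving candidate resumes from the same point.
    input_tokens = best_results[:, :token_goal]
    # inference_speech (as extended by this commit) appends input_tokens after the start token and
    # generates at most max_generate_length tokens past that boundary; the partial tokens count
    # toward that budget, so this asks for roughly tokens_per_clip_inference new tokens.
    return autoregressive_model.inference_speech(conditioning_input, text_input,
                                                 input_tokens=input_tokens,
                                                 max_generate_length=token_goal + tokens_per_clip_inference,
                                                 **generation_kwargs)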
@@ -136,14 +160,14 @@ class TextToSpeech:
                                          heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
                                          train_solo_embeddings=False,
                                          average_conditioning_embeddings=True).cpu().eval()
-        self.autoregressive.load_state_dict(torch.load('.models/autoregressive_diverse.pth'))
+        self.autoregressive.load_state_dict(torch.load('.models/autoregressive_audiobooks.pth'))

        self.autoregressive_for_latents = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
                                                       model_dim=1024,
                                                       heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
                                                       train_solo_embeddings=False,
                                                       average_conditioning_embeddings=True).cpu().eval()
-        self.autoregressive_for_latents.load_state_dict(torch.load('.models/autoregressive_diverse.pth'))
+        self.autoregressive_for_latents.load_state_dict(torch.load('.models/autoregressive_audiobooks.pth'))

        self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
                              text_seq_len=350, text_heads=8,
@@ -154,7 +178,7 @@ class TextToSpeech:
        self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                      in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
                                      layer_drop=0, unconditioned_percentage=0).cpu().eval()
-        self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
+        self.diffusion.load_state_dict(torch.load('.models/diffusion_decoder_audiobooks.pth'))

        self.vocoder = UnivNetGenerator().cpu()
        self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
@@ -170,7 +194,7 @@ class TextToSpeech:
        presets = {
            'intelligible': {'temperature': .5, 'length_penalty': 2.0, 'repetition_penalty': 2.0, 'top_p': .5, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': .7, 'diffusion_temperature': .7},
            'mid': {'temperature': .7, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .7, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 1.5, 'diffusion_temperature': .8},
-            'realistic': {'temperature': .9, 'length_penalty': 1.0, 'repetition_penalty': 1.3, 'top_p': .9, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 2, 'diffusion_temperature': 1},
+            'realistic': {'temperature': 1.0, 'length_penalty': 1.0, 'repetition_penalty': 2.0, 'top_p': .9, 'diffusion_iterations': 100, 'cond_free': True, 'cond_free_k': 2, 'diffusion_temperature': 1},
        }
        kwargs.update(presets[preset])
        return self.tts(text, voice_samples, **kwargs)
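For reference, these presets are simply merged into the kwargs handed to tts(); a call through the wrapper whose body is shown above (its name is not visible in this hunk, so tts_with_preset below is an assumption) might look like:

# Illustrative usage; tts_with_preset and the sample text are assumptions, not taken from this commit.
from api import TextToSpeech
from utils.audio import load_audio

tts = TextToSpeech()
conditioning = [load_audio('voices/patrick_stewart/1.wav', 22050),
                load_audio('voices/patrick_stewart/2.wav', 22050)]
# 'realistic' now runs the autoregressive sampler at temperature 1.0 with repetition_penalty 2.0.
audio = tts.tts_with_preset("Then took the other, as just as fair.", conditioning, preset='realistic')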

View File

@@ -8,7 +8,7 @@ from utils.audio import load_audio
if __name__ == '__main__':
    fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
    stop_after = 128
-    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\diverse'
+    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\audiobooks'
    outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'

    os.makedirs(outpath_real, exist_ok=True)

View File

@@ -511,7 +511,8 @@ class UnifiedVoice(nn.Module):
            loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
            return loss_mel.mean()

-    def inference_speech(self, speech_conditioning_input, text_inputs, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
+    def inference_speech(self, speech_conditioning_input, text_inputs, input_tokens=None, num_return_sequences=1,
+                         max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
        if not hasattr(self, 'inference_model'):
            # TODO: Decouple gpt_config from this inference model.
@@ -541,13 +542,23 @@ class UnifiedVoice(nn.Module):
        emb = torch.cat([conds, text_emb], dim=1)
        self.inference_model.store_mel_emb(emb)

-        fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
+        fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long,
+                                 device=text_inputs.device)
        fake_inputs[:, -1] = self.start_mel_token
+        trunc_index = fake_inputs.shape[1]
+        if input_tokens is None:
+            inputs = fake_inputs
+        else:
+            assert num_return_sequences % input_tokens.shape[0] == 0, "The number of return sequences must be divisible by the number of input sequences"
+            fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
+            input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
+            inputs = torch.cat([fake_inputs, input_tokens], dim=1)

        logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
-        gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
-                                            max_length=fake_inputs.shape[-1] + self.max_mel_tokens - 1, logits_processor=logits_processor, **hf_generate_kwargs)
-        return gen[:, fake_inputs.shape[1]:]
+        max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
+        gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
+                                            max_length=max_length, logits_processor=logits_processor, **hf_generate_kwargs)
+        return gen[:, trunc_index:]


if __name__ == '__main__':
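The net effect of the new arguments is that inference_speech can resume generation from partial mel-token sequences and fan each one out into several continuations, which is what clip_guided_generation relies on. A hedged usage sketch (the tensor and variable names here are illustrative, not from the commit):

# Sketch: resume from 2 partial candidates and produce 4 continuations of each.
# Assumes the conditioning/text batch is 1, so the repeated fake inputs line up with input_tokens.
# partial_codes: LongTensor of shape (2, T) holding already-generated mel tokens.
continuations = model.inference_speech(
    speech_conditioning_input, text_inputs,
    input_tokens=partial_codes,                       # concatenated after the start_mel_token position
    num_return_sequences=8,                           # must be divisible by partial_codes.shape[0] (see the assert)
    max_generate_length=partial_codes.shape[1] + 10,  # the partial tokens count toward this budget
    do_sample=True, top_p=0.9,                        # ordinary HF generate kwargs still pass through
)
# Each returned row is sliced at trunc_index, so it contains the supplied partial tokens
# followed by the newly generated ones.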

View File

@@ -32,15 +32,16 @@ if __name__ == '__main__':
    preselected_cond_voices = {
        'emma_stone': ['voices/emma_stone/1.wav','voices/emma_stone/2.wav','voices/emma_stone/3.wav'],
        'tom_hanks': ['voices/tom_hanks/1.wav','voices/tom_hanks/2.wav','voices/tom_hanks/3.wav'],
+        'patrick_stewart': ['voices/patrick_stewart/1.wav','voices/patrick_stewart/2.wav','voices/patrick_stewart/3.wav','voices/patrick_stewart/4.wav'],
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('-textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt")
-    parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='emma_stone')
-    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
+    parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='patrick_stewart')
+    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=128)
    parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16)
    parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/longform/')
-    parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='intelligible')
+    parser.add_argument('-generation_preset', type=str, help='Preset to use for generation', default='realistic')
    args = parser.parse_args()

    os.makedirs(args.output_path, exist_ok=True)

View File

@@ -25,16 +25,15 @@ def permutations(args):
if __name__ == '__main__':
    fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
-    stop_after = 128
-    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep'
+    stop_after = 512
+    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep-2'
    outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'

    arg_ranges = {
-        'top_p': [.5, 1],
-        'temperature': [.5, 1],
-        'diffusion_temperature': [.6, 1],
-        'cond_free_k': [0, 1, 4],
-        'repetition_penalty': [1.0, 2.0]
+        'top_p': [.8,1],
+        'temperature': [.8,.9,1],
+        'diffusion_temperature': [.8,1],
+        'cond_free_k': [1,2,5,10],
    }
    cfgs = permutations(arg_ranges)
    shuffle(cfgs)
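permutations() is defined earlier in this file and is not part of this diff; for context, a helper with that role would expand the value lists above into every combination of keyword arguments, roughly like this (a guess at the behaviour, not the repository's code):

import itertools

def permutations(args):
    # Cartesian product over the per-key value lists -> one kwargs dict per combination.
    keys = list(args.keys())
    return [dict(zip(keys, values)) for values in itertools.product(*(args[k] for k in keys))]

# With the updated ranges above this yields 2 * 3 * 2 * 4 = 48 configurations to sweep.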
@@ -56,8 +55,8 @@ if __name__ == '__main__':
        path = os.path.join(os.path.dirname(fname), line[1])
        cond_audio = load_audio(path, 22050)
        torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
-        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256,
-                         k=1, diffusion_iterations=70, length_penalty=1.0, **cfg)
+        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=32, repetition_penalty=2.0,
+                         k=1, diffusion_iterations=32, length_penalty=1.0, **cfg)
        down = torchaudio.functional.resample(sample, 24000, 22050)
        fout_path = os.path.join(outpath, os.path.basename(line[1]))
        torchaudio.save(fout_path, down.squeeze(0), 22050)