Clip diffusion inputs

This commit is contained in:
James Betker 2022-04-10 19:29:32 -06:00
parent b1ba8416ff
commit b07fb37a78

17
api.py
View File

@ -181,6 +181,7 @@ class TextToSpeech:
samples = [] samples = []
num_batches = num_autoregressive_samples // self.autoregressive_batch_size num_batches = num_autoregressive_samples // self.autoregressive_batch_size
stop_mel_token = self.autoregressive.stop_mel_token stop_mel_token = self.autoregressive.stop_mel_token
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
self.autoregressive = self.autoregressive.cuda() self.autoregressive = self.autoregressive.cuda()
for b in tqdm(range(num_batches)): for b in tqdm(range(num_batches)):
codes = self.autoregressive.inference_speech(conds, text, codes = self.autoregressive.inference_speech(conds, text,
@ -212,8 +213,20 @@ class TextToSpeech:
self.diffusion = self.diffusion.cuda() self.diffusion = self.diffusion.cuda()
self.vocoder = self.vocoder.cuda() self.vocoder = self.vocoder.cuda()
for b in range(best_results.shape[0]): for b in range(best_results.shape[0]):
code = best_results[b].unsqueeze(0) codes = best_results[b].unsqueeze(0)
mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, voice_samples, temperature=diffusion_temperature)
# Find the first occurrence of the "calm" token and trim the codes to that.
ctokens = 0
for k in range(codes.shape[-1]):
if codes[0, k] == calm_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
codes = codes[:, :k]
break
mel = do_spectrogram_diffusion(self.diffusion, diffuser, codes, voice_samples, temperature=diffusion_temperature)
wav = self.vocoder.inference(mel) wav = self.vocoder.inference(mel)
wav_candidates.append(wav.cpu()) wav_candidates.append(wav.cpu())
self.diffusion = self.diffusion.cpu() self.diffusion = self.diffusion.cpu()