un-hardcoded input output sampling rates (changing them "works" but leads to wrong audio, naturally)

This commit is contained in:
mrq 2023-02-07 18:34:29 +00:00
parent 5f934c5feb
commit 793515772a
3 changed files with 27 additions and 20 deletions

16
app.py
View File

@ -27,7 +27,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
if voice == "microphone":
if mic_audio is None:
raise gr.Error("Please provide audio from mic when choosing `microphone` as a voice input")
mic = load_audio(mic_audio, 22050)
mic = load_audio(mic_audio, tts.input_sample_rate)
voice_samples, conditioning_latents = [mic], None
else:
progress(0, desc="Loading voice...")
@ -105,14 +105,14 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
}
os.makedirs(f'{outdir}/candidate_{j}', exist_ok=True)
torchaudio.save(f'{outdir}/candidate_{j}/result_{line}.wav', audio, 24000)
torchaudio.save(f'{outdir}/candidate_{j}/result_{line}.wav', audio, tts.output_sample_rate)
else:
audio = gen.squeeze(0).cpu()
audio_cache[f"result_{line}.wav"] = {
'audio': audio,
'text': cut_text,
}
torchaudio.save(f'{outdir}/result_{line}.wav', audio, 24000)
torchaudio.save(f'{outdir}/result_{line}.wav', audio, tts.output_sample_rate)
output_voice = None
if len(texts) > 1:
@ -126,7 +126,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
audio_clips.append(audio)
audio = torch.cat(audio_clips, dim=-1)
torchaudio.save(f'{outdir}/combined_{candidate}.wav', audio, 24000)
torchaudio.save(f'{outdir}/combined_{candidate}.wav', audio, tts.output_sample_rate)
audio = audio.squeeze(0).cpu()
audio_cache[f'combined_{candidate}.wav'] = {
@ -143,7 +143,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
output_voice = gen
if output_voice is not None:
output_voice = (24000, output_voice.numpy())
output_voice = (tts.output_sample_rate, output_voice.numpy())
info = {
'text': text,
@ -179,7 +179,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, seed, candidate
metadata.save()
if sample_voice is not None:
sample_voice = (22050, sample_voice.squeeze().cpu().numpy())
sample_voice = (tts.input_sample_rate, sample_voice.squeeze().cpu().numpy())
print(f"Generation took {info['time']} seconds, saved to '{outdir}'\n")
@ -514,6 +514,8 @@ if __name__ == "__main__":
args = parser.parse_args()
print("Initializating TorToiSe...")
tts = TextToSpeech(minor_optimizations=not args.low_vram)
tts = TextToSpeech(
minor_optimizations=not args.low_vram,
)
main()

View File

@ -114,7 +114,7 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
def format_conditioning(clip, cond_length=132300, device='cuda'):
def format_conditioning(clip, cond_length=132300, device='cuda', sampling_rate=22050):
"""
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
"""
@ -124,7 +124,7 @@ def format_conditioning(clip, cond_length=132300, device='cuda'):
elif gap > 0:
rand_start = random.randint(0, gap)
clip = clip[:, rand_start:rand_start + cond_length]
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
mel_clip = TorchMelSpectrogram(sampling_rate=sample_rate)(clip.unsqueeze(0)).squeeze(0)
return mel_clip.unsqueeze(0).to(device)
@ -158,12 +158,12 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
return codes
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True, progress=None, desc=None, sampler="P"):
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True, progress=None, desc=None, sampler="P", input_sample_rate=22050, output_sample_rate=24000):
"""
Uses the specified diffusion model to convert discrete codes into a spectrogram.
"""
with torch.no_grad():
output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_seq_len = latents.shape[1] * 4 * output_sample_rate // input_sample_rate # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_shape = (latents.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False)
@ -214,7 +214,7 @@ class TextToSpeech:
Main entry point into Tortoise.
"""
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True):
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000):
"""
Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@ -234,7 +234,10 @@ class TextToSpeech:
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.input_sample_rate = input_sample_rate
self.output_sample_rate = output_sample_rate
self.minor_optimizations = minor_optimizations
self.models_dir = models_dir
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None or autoregressive_batch_size == 0 else autoregressive_batch_size
self.enable_redaction = enable_redaction
@ -306,7 +309,7 @@ class TextToSpeech:
if not isinstance(voice_samples, list):
voice_samples = [voice_samples]
for vs in voice_samples:
auto_conds.append(format_conditioning(vs, device=self.device))
auto_conds.append(format_conditioning(vs, device=self.device, sampling_rate=self.input_sample_rate))
auto_conds = torch.stack(auto_conds, dim=1)
@ -315,7 +318,8 @@ class TextToSpeech:
samples = [] # resample in its own pass to make things easier
for sample in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
samples.append(torchaudio.functional.resample(sample, 22050, 24000))
#samples.append(torchaudio.functional.resample(sample, 22050, 24000))
samples.append(torchaudio.functional.resample(sample, self.input_sample_rate, self.output_sample_rate))
if chunk_size is None:
for sample in tqdm_override(samples, verbose=verbose and len(samples) > 1, progress=progress if len(samples) > 1 else None, desc="Calculating size of best fit..."):
@ -582,7 +586,8 @@ class TextToSpeech:
break
mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, diffusion_conditioning,
temperature=diffusion_temperature, verbose=verbose, progress=progress, desc="Transforming autoregressive outputs into audio..", sampler=diffusion_sampler)
temperature=diffusion_temperature, verbose=verbose, progress=progress, desc="Transforming autoregressive outputs into audio..", sampler=diffusion_sampler,
input_sample_rate=self.input_sample_rate, output_sample_rate=self.output_sample_rate)
wav = self.vocoder.inference(mel)
wav_candidates.append(wav.cpu())
@ -592,7 +597,7 @@ class TextToSpeech:
def potentially_redact(clip, text):
if self.enable_redaction:
return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
return self.aligner.redact(clip.squeeze(1), text, self.output_sample_rate).unsqueeze(1)
return clip
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]

View File

@ -97,7 +97,7 @@ def get_voices(extra_voice_dirs=[]):
return voices
def load_voice(voice, extra_voice_dirs=[], load_latents=True):
def load_voice(voice, extra_voice_dirs=[], load_latents=True, sample_rate=22050):
if voice == 'random':
return None, None
@ -125,7 +125,7 @@ def load_voice(voice, extra_voice_dirs=[], load_latents=True):
conds = []
for cond_path in voices:
c = load_audio(cond_path, 22050)
c = load_audio(cond_path, sample_rate)
conds.append(c)
return conds, None
@ -197,8 +197,8 @@ class TacotronSTFT(torch.nn.Module):
return mel_output
def wav_to_univnet_mel(wav, do_normalization=False, device='cuda'):
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
def wav_to_univnet_mel(wav, do_normalization=False, device='cuda', sample_rate=24000):
stft = TacotronSTFT(1024, 256, 1024, 100, sample_rate, 0, 12000)
stft = stft.to(device)
mel = stft.mel_spectrogram(wav)
if do_normalization: