forked from mrq/tortoise-tts
fixed regression where the auto_conds do not move to the GPU and causes a problem during CVVP compare pass
This commit is contained in:
parent
3d69274a46
commit
1b55730e67
|
@ -483,8 +483,11 @@ class TextToSpeech:
|
||||||
auto_conditioning, diffusion_conditioning, auto_conds, _ = conditioning_latents
|
auto_conditioning, diffusion_conditioning, auto_conds, _ = conditioning_latents
|
||||||
else:
|
else:
|
||||||
auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
|
auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
|
||||||
|
|
||||||
auto_conditioning = auto_conditioning.to(self.device)
|
auto_conditioning = auto_conditioning.to(self.device)
|
||||||
diffusion_conditioning = diffusion_conditioning.to(self.device)
|
diffusion_conditioning = diffusion_conditioning.to(self.device)
|
||||||
|
if auto_conds is not None:
|
||||||
|
auto_conds = auto_conds.to(self.device)
|
||||||
|
|
||||||
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
|
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
|
||||||
|
|
||||||
|
@ -539,8 +542,10 @@ class TextToSpeech:
|
||||||
for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
|
for batch in tqdm_override(samples, verbose=verbose, progress=progress, desc=desc):
|
||||||
for i in range(batch.shape[0]):
|
for i in range(batch.shape[0]):
|
||||||
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
||||||
|
|
||||||
if cvvp_amount != 1:
|
if cvvp_amount != 1:
|
||||||
clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
|
clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
|
||||||
|
|
||||||
if auto_conds is not None and cvvp_amount > 0:
|
if auto_conds is not None and cvvp_amount > 0:
|
||||||
cvvp_accumulator = 0
|
cvvp_accumulator = 0
|
||||||
for cl in range(auto_conds.shape[1]):
|
for cl in range(auto_conds.shape[1]):
|
||||||
|
|
2
webui.py
2
webui.py
|
@ -265,7 +265,7 @@ def generate(
|
||||||
with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f:
|
with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f:
|
||||||
info['latents'] = base64.b64encode(f.read()).decode("ascii")
|
info['latents'] = base64.b64encode(f.read()).decode("ascii")
|
||||||
|
|
||||||
if voicefixer:
|
if args.voice_fixer and voicefixer:
|
||||||
# we could do this on the pieces before they get stiched up anyways to save some compute
|
# we could do this on the pieces before they get stiched up anyways to save some compute
|
||||||
# but the stitching would need to read back from disk, defeating the point of caching the waveform
|
# but the stitching would need to read back from disk, defeating the point of caching the waveform
|
||||||
for path in progress.tqdm(audio_cache, desc="Running voicefix..."):
|
for path in progress.tqdm(audio_cache, desc="Running voicefix..."):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user