1
1
forked from mrq/tortoise-tts

cleanup loop, save files while generating a batch in the event it crashes midway through

This commit is contained in:
mrq 2023-02-12 01:15:22 +00:00
parent 1b55730e67
commit ddd0c4ccf8

View File

@ -143,12 +143,10 @@ def generate(
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
idx = 0
idx = 1
for i, file in enumerate(os.listdir(outdir)):
if file[-5:] == ".json":
idx = idx + 1
if idx:
idx = idx + 1
# reserve, if for whatever reason you manage to concurrently generate
with open(f'{outdir}/input_{idx}.json', 'w', encoding="utf-8") as f:
@ -180,24 +178,23 @@ def generate(
run_time = time.time()-start_time
print(f"Generating line took {run_time} seconds")
if isinstance(gen, list):
for j, g in enumerate(gen):
name = get_name(line=line, candidate=j)
audio_cache[name] = {
'audio': g,
'text': cut_text,
'time': run_time
}
else:
name = get_name(line=line)
if not isinstance(gen, list):
gen = [gen]
for j, g in enumerate(gen):
audio = g.squeeze(0).cpu()
name = get_name(line=line, candidate=j)
audio_cache[name] = {
'audio': gen,
'audio': audio,
'text': cut_text,
'time': run_time,
'time': run_time
}
# save here in case some error happens mid-batch
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
for k in audio_cache:
audio = audio_cache[k]['audio'].squeeze(0).cpu()
audio = audio_cache[k]['audio']
if resampler is not None:
audio = resampler(audio)
if volume_adjust is not None: