forked from mrq/tortoise-tts
Update sweep & eval_multiple with new voices
parent 732deaa212
commit 2eb5d4b0cb

api.py (13 changes)
@@ -140,6 +140,13 @@ class TextToSpeech:
                                            average_conditioning_embeddings=True).cpu().eval()
         self.autoregressive.load_state_dict(torch.load('.models/autoregressive_diverse.pth'))
 
+        self.autoregressive_for_latents = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
+                                            model_dim=1024,
+                                            heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
+                                            train_solo_embeddings=False,
+                                            average_conditioning_embeddings=True).cpu().eval()
+        self.autoregressive_for_latents.load_state_dict(torch.load('.models/autoregressive_diverse.pth'))
+
         self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
                               text_seq_len=350, text_heads=8,
                               num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
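Note: autoregressive_for_latents is a second UnifiedVoice instance loaded from the same .models/autoregressive_diverse.pth checkpoint as self.autoregressive. A minimal sketch of that duplicate-load pattern, with a toy module standing in for UnifiedVoice (names below are illustrative, not from the repo):

    import torch

    ckpt_path = 'checkpoint.pth'              # stand-in for .models/autoregressive_diverse.pth
    model_a = torch.nn.Linear(16, 16).eval()  # toy stand-in for UnifiedVoice(...)
    model_b = torch.nn.Linear(16, 16).eval()
    torch.save(model_a.state_dict(), ckpt_path)
    state = torch.load(ckpt_path)
    model_a.load_state_dict(state)            # both instances receive identical weights,
    model_b.load_state_dict(state)            # but can be moved between devices independently

This keeps sampling and latent extraction decoupled even though the weights start out identical.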
@@ -221,11 +228,11 @@ class TextToSpeech:
         # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
         # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
         # results, but will increase memory usage.
-        self.autoregressive = self.autoregressive.cuda()
-        best_latents = self.autoregressive(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results,
+        self.autoregressive_for_latents = self.autoregressive_for_latents.cuda()
+        best_latents = self.autoregressive_for_latents(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results,
                                            torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device),
                                            return_latent=True, clip_inputs=False)
-        self.autoregressive = self.autoregressive.cpu()
+        self.autoregressive_for_latents = self.autoregressive_for_latents.cpu()
 
         print("Performing vocoding..")
         wav_candidates = []
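The surrounding code illustrates a VRAM-saving idiom: each large model lives on the CPU and is moved to the GPU only for the step that needs it. A minimal sketch of that shuttling pattern, assuming a CUDA device is available (the module below is a toy stand-in):

    import torch

    model = torch.nn.Linear(1024, 1024).cpu().eval()

    def run_once(model, x):
        model = model.cuda()              # load weights into VRAM for this call only
        with torch.no_grad():
            y = model(x.cuda()).cpu()     # compute on GPU, bring the result back
        model.cpu()                       # release VRAM for the next model
        return y

Per the comment in the hunk above, return_latent=True asks the autoregressive model for its last hidden layer rather than decoded tokens, which is what the diffusion model consumes as conditioning.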
eval_multiple.py

@@ -6,32 +6,35 @@ from api import TextToSpeech
 from utils.audio import load_audio
 
 if __name__ == '__main__':
-    fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
-    outpath = 'D:\\tmp\\tortoise-tts-eval\\diverse_new_decoder_1'
+    fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
+    stop_after = 128
+    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\diverse'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
-    os.makedirs(outpath, exist_ok=True)
     os.makedirs(outpath_real, exist_ok=True)
     with open(fname, 'r', encoding='utf-8') as f:
         lines = [l.strip().split('\t') for l in f.readlines()]
 
-    recorder = open(os.path.join(outpath, 'transcript.tsv'), 'w', encoding='utf-8')
     tts = TextToSpeech()
-    for e, line in enumerate(lines):
-        transcript = line[0]
-        if len(transcript) > 120:
-            continue # We need to support this, but cannot yet.
-        path = os.path.join(os.path.dirname(fname), line[1])
-        cond_audio = load_audio(path, 22050)
-        torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
-        sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256, k=1,
-                         repetition_penalty=2.0, length_penalty=2, temperature=.5, top_p=.5,
-                         diffusion_temperature=.7, cond_free_k=2, diffusion_iterations=100)
+    for k in range(4):
+        outpath = f'{outpath_base}_{k}'
+        os.makedirs(outpath, exist_ok=True)
+        recorder = open(os.path.join(outpath, 'transcript.tsv'), 'w', encoding='utf-8')
+        for e, line in enumerate(lines):
+            if e >= stop_after:
+                break
+            transcript = line[0]
+            path = os.path.join(os.path.dirname(fname), line[1])
+            cond_audio = load_audio(path, 22050)
+            torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
+            sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=128, k=1,
+                             repetition_penalty=2.0, length_penalty=2, temperature=.5, top_p=.5,
+                             diffusion_temperature=.7, cond_free_k=2, diffusion_iterations=70)
 
-        down = torchaudio.functional.resample(sample, 24000, 22050)
-        fout_path = os.path.join(outpath, os.path.basename(line[1]))
-        torchaudio.save(fout_path, down.squeeze(0), 22050)
+            down = torchaudio.functional.resample(sample, 24000, 22050)
+            fout_path = os.path.join(outpath, os.path.basename(line[1]))
+            torchaudio.save(fout_path, down.squeeze(0), 22050)
 
-        recorder.write(f'{transcript}\t{fout_path}\n')
-        recorder.flush()
-    recorder.close()
+            recorder.write(f'{transcript}\t{fout_path}\n')
+            recorder.flush()
+        recorder.close()
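The restructured loop renders each of the first stop_after clips four times, once per k, so the outputs land in sibling directories. The resulting layout looks roughly like this (paths illustrative, derived from outpath = f'{outpath_base}_{k}'):

    D:\tmp\tortoise-tts-eval\diverse_0\transcript.tsv
    D:\tmp\tortoise-tts-eval\diverse_0\<clip>.wav
    ...
    D:\tmp\tortoise-tts-eval\diverse_3\transcript.tsv
    D:\tmp\tortoise-tts-eval\real\<clip>.wav          (22 kHz copies of the conditioning audio)

Each transcript.tsv pairs the input text with the generated wav path, one line per sample.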
read.py (17 changes)
@@ -30,24 +30,13 @@ if __name__ == '__main__':
     # These are voices drawn randomly from the training set. You are free to substitute your own voices in, but testing
     # has shown that the model does not generalize to new voices very well.
     preselected_cond_voices = {
-        # Male voices
-        'dotrice': ['voices/dotrice/1.wav', 'voices/dotrice/2.wav'],
-        'harris': ['voices/harris/1.wav', 'voices/harris/2.wav'],
-        'lescault': ['voices/lescault/1.wav', 'voices/lescault/2.wav'],
-        'otto': ['voices/otto/1.wav', 'voices/otto/2.wav'],
-        'obama': ['voices/obama/1.wav', 'voices/obama/2.wav'],
-        'carlin': ['voices/carlin/1.wav', 'voices/carlin/2.wav'],
-        # Female voices
-        'atkins': ['voices/atkins/1.wav', 'voices/atkins/2.wav'],
-        'grace': ['voices/grace/1.wav', 'voices/grace/2.wav'],
-        'kennard': ['voices/kennard/1.wav', 'voices/kennard/2.wav'],
-        'mol': ['voices/mol/1.wav', 'voices/mol/2.wav'],
-        'lj': ['voices/lj/1.wav', 'voices/lj/2.wav'],
+        'emma_stone': ['voices/emma_stone/1.wav','voices/emma_stone/2.wav','voices/emma_stone/3.wav'],
+        'tom_hanks': ['voices/tom_hanks/1.wav','voices/tom_hanks/2.wav','voices/tom_hanks/3.wav'],
     }
 
     parser = argparse.ArgumentParser()
     parser.add_argument('-textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt")
-    parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice')
+    parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='emma_stone')
     parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
     parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16)
     parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/longform/')
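With the new presets, a typical invocation looks like this (using the argparse defaults shown above; 'tom_hanks' could be substituted for the -voice value):

    python read.py -voice emma_stone -textfile data/riding_hood.txt -num_samples 512 -batch_size 16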
sweep.py (24 changes)
@@ -24,19 +24,24 @@ def permutations(args):
 
 
 if __name__ == '__main__':
-    fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
-    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\std_sweep3'
+    fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
+    stop_after = 128
+    outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
     arg_ranges = {
-        'top_p': [.3,.4,.5,.6],
-        'temperature': [.5, .6],
+        'top_p': [.5, 1],
+        'temperature': [.5, 1],
+        'diffusion_temperature': [.6, 1],
+        'cond_free_k': [0, 1, 4],
+        'repetition_penalty': [1.0, 2.0]
     }
     cfgs = permutations(arg_ranges)
     shuffle(cfgs)
 
     for cfg in cfgs:
-        outpath = os.path.join(outpath_base, f'{cfg["top_p"]}_{cfg["temperature"]}')
+        cfg_desc = '_'.join([f'{k}-{v}' for k,v in cfg.items()])
+        outpath = os.path.join(outpath_base, f'{cfg_desc}')
         os.makedirs(outpath, exist_ok=True)
         os.makedirs(outpath_real, exist_ok=True)
         with open(fname, 'r', encoding='utf-8') as f:
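The permutations helper is defined above this hunk and isn't shown in the diff; its usage implies a cartesian product over the value lists in arg_ranges, yielding one dict per combination. A minimal sketch of such a helper (an assumption — the repo's actual implementation may differ):

    import itertools

    def permutations(args):
        # Expand {'top_p': [.5, 1], 'temperature': [.5, 1], ...} into one
        # dict per combination of values, e.g. {'top_p': .5, 'temperature': 1, ...}.
        keys = list(args.keys())
        return [dict(zip(keys, vals)) for vals in itertools.product(*(args[k] for k in keys))]

With the new five-key ranges this yields 2*2*2*3*2 = 48 configurations, and cfg_desc encodes each one into a directory name such as top_p-0.5_temperature-1_diffusion_temperature-0.6_cond_free_k-0_repetition_penalty-1.0.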
@@ -45,15 +50,14 @@ if __name__ == '__main__':
         recorder = open(os.path.join(outpath, 'transcript.tsv'), 'w', encoding='utf-8')
         tts = TextToSpeech()
         for e, line in enumerate(lines):
+            if e >= stop_after:
+                break
             transcript = line[0]
-            if len(transcript) > 120:
-                continue # We need to support this, but cannot yet.
             path = os.path.join(os.path.dirname(fname), line[1])
             cond_audio = load_audio(path, 22050)
             torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
-            sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256, k=1, diffusion_iterations=200,
-                             repetition_penalty=2.0, length_penalty=2, temperature=.5, top_p=.5,
-                             diffusion_temperature=.7, cond_free_k=2, **cfg)
+            sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256,
+                             k=1, diffusion_iterations=70, length_penalty=1.0, **cfg)
             down = torchaudio.functional.resample(sample, 24000, 22050)
             fout_path = os.path.join(outpath, os.path.basename(line[1]))
             torchaudio.save(fout_path, down.squeeze(0), 22050)
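After the refactor, the fixed arguments (k=1, diffusion_iterations=70, length_penalty=1.0) are pinned in the call and every swept parameter arrives through **cfg. Expanded for one illustrative configuration drawn from the new ranges, the call is equivalent to:

    sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=256,
                     k=1, diffusion_iterations=70, length_penalty=1.0,
                     top_p=.5, temperature=1, diffusion_temperature=.6,
                     cond_free_k=0, repetition_penalty=1.0)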