forked from mrq/tortoise-tts
Add in ASR filtration
This commit is contained in:
parent
79c74c1484
commit
adccaa44bc
48
do_tts.py
48
do_tts.py
|
@ -7,6 +7,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import progressbar
|
||||
import ocotillo
|
||||
|
||||
from models.diffusion_decoder import DiffusionTts
|
||||
from models.autoregressive import UnifiedVoice
|
||||
|
@ -17,7 +18,7 @@ from models.text_voice_clip import VoiceCLIP
|
|||
from models.vocoder import UnivNetGenerator
|
||||
from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
|
||||
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
||||
from utils.tokenizer import VoiceBpeTokenizer
|
||||
from utils.tokenizer import VoiceBpeTokenizer, lev_distance
|
||||
|
||||
pbar = None
|
||||
def download_models():
|
||||
|
@ -47,13 +48,13 @@ def download_models():
|
|||
print('Done.')
|
||||
|
||||
|
||||
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
|
||||
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True):
|
||||
"""
|
||||
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
||||
"""
|
||||
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
||||
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
|
||||
conditioning_free=True, conditioning_free_k=1)
|
||||
conditioning_free=cond_free, conditioning_free_k=1)
|
||||
|
||||
|
||||
def load_conditioning(path, sample_rate=22050, cond_length=132300):
|
||||
|
@ -109,11 +110,12 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
|
|||
mel = torch.nn.functional.pad(mel_codes, (0, gap))
|
||||
|
||||
output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
|
||||
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel)
|
||||
if mean:
|
||||
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
|
||||
model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
|
||||
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
||||
else:
|
||||
mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
|
||||
mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
||||
return denormalize_tacotron_mel(mel)[:,:,:msl*4]
|
||||
|
||||
|
||||
|
@ -136,9 +138,9 @@ if __name__ == '__main__':
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
||||
parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
|
||||
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=512)
|
||||
parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=16)
|
||||
parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2)
|
||||
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=1024)
|
||||
parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=32)
|
||||
parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16)
|
||||
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -192,7 +194,7 @@ if __name__ == '__main__':
|
|||
return_loss=False))
|
||||
clip_results = torch.cat(clip_results, dim=0)
|
||||
samples = torch.cat(samples, dim=0)
|
||||
best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices]
|
||||
best_results = samples[torch.topk(clip_results, k=args.num_diffusion_samples).indices]
|
||||
|
||||
# Delete the autoregressive and clip models to free up GPU memory
|
||||
del samples, clip
|
||||
|
@ -210,12 +212,32 @@ if __name__ == '__main__':
|
|||
vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
|
||||
vocoder = vocoder.cuda()
|
||||
vocoder.eval(inference=True)
|
||||
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
|
||||
initial_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=40, cond_free=False)
|
||||
final_diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=500)
|
||||
|
||||
print("Performing vocoding..")
|
||||
# Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive.
|
||||
wav_candidates = []
|
||||
for b in range(best_results.shape[0]):
|
||||
code = best_results[b].unsqueeze(0)
|
||||
mel = do_spectrogram_diffusion(diffusion, diffuser, code, cond_diffusion, mean=False)
|
||||
mel = do_spectrogram_diffusion(diffusion, initial_diffuser, code, cond_diffusion, mean=False)
|
||||
wav = vocoder.inference(mel)
|
||||
torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 24000)
|
||||
wav_candidates.append(wav.cpu())
|
||||
|
||||
# Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable.
|
||||
transcriber = ocotillo.Transcriber(on_cuda=True)
|
||||
transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000)
|
||||
best = 99999999
|
||||
for i, transcription in enumerate(transcriptions):
|
||||
dist = lev_distance(transcription, args.text.lower())
|
||||
if dist < best:
|
||||
best = dist
|
||||
best_codes = best_results[i].unsqueeze(0)
|
||||
best_wav = wav_candidates[i]
|
||||
del transcriber
|
||||
torchaudio.save(os.path.join(args.output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000)
|
||||
|
||||
# Perform diffusion again with the high-quality diffuser.
|
||||
mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False)
|
||||
wav = vocoder.inference(mel)
|
||||
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000)
|
||||
|
||||
|
|
|
@ -486,66 +486,40 @@ class DiffusionTts(nn.Module):
|
|||
aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
|
||||
return x, aligned_conditioning
|
||||
|
||||
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
||||
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
|
||||
:param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
|
||||
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert conditioning_input is not None
|
||||
if self.super_sampling_enabled:
|
||||
assert lr_input is not None
|
||||
if self.training and self.super_sampling_max_noising_factor > 0:
|
||||
noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
|
||||
lr_input = torch.randn_like(lr_input) * noising_factor + lr_input
|
||||
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
|
||||
x = torch.cat([x, lr_input], dim=1)
|
||||
|
||||
def timestep_independent(self, aligned_conditioning, conditioning_input):
|
||||
# Shuffle aligned_latent to BxCxS format
|
||||
if is_latent(aligned_conditioning):
|
||||
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
||||
|
||||
# Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
|
||||
orig_x_shape = x.shape[-1]
|
||||
x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
|
||||
with autocast(aligned_conditioning.device.type, enabled=self.enable_fp16):
|
||||
cond_emb = self.contextual_embedder(conditioning_input)
|
||||
if len(cond_emb.shape) == 3: # Just take the first element.
|
||||
cond_emb = cond_emb[:, :, 0]
|
||||
if is_latent(aligned_conditioning):
|
||||
code_emb = self.latent_converter(aligned_conditioning)
|
||||
else:
|
||||
code_emb = self.code_converter(aligned_conditioning)
|
||||
cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
|
||||
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
|
||||
return code_emb
|
||||
|
||||
def forward(self, x, timesteps, precomputed_aligned_embeddings, conditioning_free=False):
|
||||
assert x.shape[-1] % self.alignment_size == 0
|
||||
|
||||
with autocast(x.device.type, enabled=self.enable_fp16):
|
||||
hs = []
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
# Note: this block does not need to repeated on inference, since it is not timestep-dependent.
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
|
||||
else:
|
||||
cond_emb = self.contextual_embedder(conditioning_input)
|
||||
if len(cond_emb.shape) == 3: # Just take the first element.
|
||||
cond_emb = cond_emb[:, :, 0]
|
||||
if is_latent(aligned_conditioning):
|
||||
code_emb = self.latent_converter(aligned_conditioning)
|
||||
else:
|
||||
code_emb = self.code_converter(aligned_conditioning)
|
||||
cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
|
||||
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
|
||||
code_emb)
|
||||
code_emb = precomputed_aligned_embeddings
|
||||
|
||||
# Everything after this comment is timestep dependent.
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
|
||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||
|
||||
first = True
|
||||
time_emb = time_emb.float()
|
||||
h = x
|
||||
hs = []
|
||||
for k, module in enumerate(self.input_blocks):
|
||||
if isinstance(module, nn.Conv1d):
|
||||
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
|
||||
|
@ -565,14 +539,7 @@ class DiffusionTts(nn.Module):
|
|||
h = h.float()
|
||||
out = self.out(h)
|
||||
|
||||
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
||||
extraneous_addition = 0
|
||||
params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
|
||||
for p in params:
|
||||
extraneous_addition = extraneous_addition + p.mean()
|
||||
out = out + extraneous_addition * 0
|
||||
|
||||
return out[:, :, :orig_x_shape]
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -7,4 +7,5 @@ inflect
|
|||
progressbar
|
||||
einops
|
||||
unidecode
|
||||
x-transformers
|
||||
x-transformers
|
||||
ocotillo
|
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
|
@ -148,6 +148,20 @@ def english_cleaners(text):
|
|||
text = text.replace('"', '')
|
||||
return text
|
||||
|
||||
def lev_distance(s1, s2):
|
||||
if len(s1) > len(s2):
|
||||
s1, s2 = s2, s1
|
||||
|
||||
distances = range(len(s1) + 1)
|
||||
for i2, c2 in enumerate(s2):
|
||||
distances_ = [i2 + 1]
|
||||
for i1, c1 in enumerate(s1):
|
||||
if c1 == c2:
|
||||
distances_.append(distances[i1])
|
||||
else:
|
||||
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
|
||||
distances = distances_
|
||||
return distances[-1]
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file='data/tokenizer.json'):
|
||||
|
|
Loading…
Reference in New Issue
Block a user