support latents into the diffusion decoder

James Betker 2022-04-12 20:53:09 -06:00
parent 5988aa34eb
commit 732deaa212
4 changed files with 55 additions and 29 deletions

api.py

@@ -117,7 +117,7 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
             cond_mels.append(cond_mel)
         cond_mels = torch.stack(cond_mels, dim=1)
 
-        output_seq_len = mel_codes.shape[-1]*4*24000//22050  # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
+        output_seq_len = mel_codes.shape[1]*4*24000//22050  # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
         output_shape = (mel_codes.shape[0], 100, output_seq_len)
         precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False)
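
As a sanity check on the length arithmetic in this hunk: each mel code covers four 22.05kHz spectrogram frames, and the frame count is then rescaled to the diffusion model's 24kHz output rate. A minimal sketch (the code count of 250 is an arbitrary assumed value):

mel_code_count = 250                            # e.g. mel_codes.shape[1]
frames_22k = mel_code_count * 4                 # each code spans 4 spectrogram frames
output_seq_len = frames_22k * 24000 // 22050    # rescale 22.05kHz -> 24kHz frame count
print(output_seq_len)                           # 1088

The switch from shape[-1] to shape[1] matters because the conditioning may now be a latent of shape (b, s, c), where the last dimension is channels rather than sequence length.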
@@ -151,11 +151,6 @@ class TextToSpeech:
                                           layer_drop=0, unconditioned_percentage=0).cpu().eval()
         self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
-        self.diffusion_next = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
-                                           in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
-                                           layer_drop=0, unconditioned_percentage=0).cpu().eval()
-        self.diffusion_next.load_state_dict(torch.load('.models/diffusion_next.pth'))
-
         self.vocoder = UnivNetGenerator().cpu()
         self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
         self.vocoder.eval(inference=True)
@@ -223,12 +218,22 @@ class TextToSpeech:
             self.clip = self.clip.cpu()
             del samples
 
+            # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
+            # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
+            # results, but will increase memory usage.
+            self.autoregressive = self.autoregressive.cuda()
+            best_latents = self.autoregressive(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results,
+                                               torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device),
+                                               return_latent=True, clip_inputs=False)
+            self.autoregressive = self.autoregressive.cpu()
+
             print("Performing vocoding..")
             wav_candidates = []
             self.diffusion = self.diffusion.cuda()
             self.vocoder = self.vocoder.cuda()
             for b in range(best_results.shape[0]):
                 codes = best_results[b].unsqueeze(0)
+                latents = best_latents[b].unsqueeze(0)
 
                 # Find the first occurrence of the "calm" token and trim the codes to that.
                 ctokens = 0
@@ -238,10 +243,10 @@ class TextToSpeech:
                     else:
                         ctokens = 0
                     if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
-                        codes = codes[:, :k]
+                        latents = latents[:, :k]
                         break
 
-                mel = do_spectrogram_diffusion(self.diffusion, diffuser, codes, voice_samples, temperature=diffusion_temperature)
+                mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, voice_samples, temperature=diffusion_temperature)
                 wav = self.vocoder.inference(mel)
                 wav_candidates.append(wav.cpu())
             self.diffusion = self.diffusion.cpu()
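
Condensed, the revised inference path in api.py reads roughly like the sketch below; variable names follow the diff, but shapes and the surrounding plumbing are assumptions:

# Re-run the autoregressive model once over the winning codes to recover its
# final hidden states, which now condition the diffusion decoder directly.
best_latents = autoregressive(conds, text, text_lengths, best_results, wav_lengths,
                              return_latent=True, clip_inputs=False)   # (b, s, 1024), assumed
for b in range(best_results.shape[0]):
    codes = best_results[b].unsqueeze(0)
    latents = best_latents[b].unsqueeze(0)
    # ...trim `latents` at the first sustained run of "calm" codes, as above...
    mel = do_spectrogram_diffusion(diffusion, diffuser, latents, voice_samples)
    wav = vocoder.inference(mel)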


@@ -7,7 +7,7 @@ from utils.audio import load_audio
 
 if __name__ == '__main__':
     fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
-    outpath = 'D:\\tmp\\tortoise-tts-eval\\diverse_auto_256_samp_100_di_4'
+    outpath = 'D:\\tmp\\tortoise-tts-eval\\diverse_new_decoder_1'
     outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
 
     os.makedirs(outpath, exist_ok=True)


@@ -362,7 +362,7 @@ class UnifiedVoice(nn.Module):
                 mel_input_tokens[b, actual_end:] = self.stop_mel_token
         return mel_input_tokens
 
-    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False):
+    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
         if second_inputs is not None:
             emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
         else:
@@ -374,6 +374,10 @@ class UnifiedVoice(nn.Module):
         enc = gpt_out.last_hidden_state[:, 1:]  # The first logit is tied to the speech_conditioning_input
         enc = self.final_norm(enc)
+
+        if return_latent:
+            return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
+
         first_logits = enc[:, :first_inputs.shape[1]]
         first_logits = first_head(first_logits)
         first_logits = first_logits.permute(0,2,1)
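
The two slices returned under return_latent correspond to the layout of the sequence fed to the GPT, [conditioning | first_inputs | second_inputs]. A runnable toy check of that slicing (all dimensions here are made up):

import torch

cond_len, first_len, second_len, dim = 1, 5, 8, 16
enc = torch.randn(2, cond_len + first_len + second_len, dim)
first_latent = enc[:, cond_len:cond_len + first_len]   # hidden states over first_inputs
second_latent = enc[:, -second_len:]                   # hidden states over second_inputs
assert first_latent.shape == (2, first_len, dim)
assert second_latent.shape == (2, second_len, dim)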
@@ -385,7 +389,8 @@ class UnifiedVoice(nn.Module):
         else:
             return first_logits
 
-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
+                return_latent=False, clip_inputs=True):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -396,19 +401,23 @@ class UnifiedVoice(nn.Module):
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
         raw_mels: MEL float tensor (b,80,s)
+        If return_attentions is specified, only logits are returned.
+        If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
+        If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
-        assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
-        assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
-
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_text_len = text_lengths.max()
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
-        max_mel_len = wav_lengths.max() // self.mel_length_compression
-        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
-        if raw_mels is not None:
-            raw_mels = raw_mels[:, :, :max_mel_len*4]
+        if clip_inputs:
+            # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+            # chopping the inputs by the maximum actual length.
+            max_text_len = text_lengths.max()
+            text_inputs = text_inputs[:, :max_text_len]
+            max_mel_len = wav_lengths.max() // self.mel_length_compression
+            mel_codes = mel_codes[:, :max_mel_len]
+            if raw_mels is not None:
+                raw_mels = raw_mels[:, :, :max_mel_len*4]
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
+        text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
+        mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
 
         speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
         conds = []
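
The reordering here is the behavioral point: clipping to the longest real sequence is now optional (clip_inputs), and the stop tokens are appended afterwards in all cases, whereas the old code fused clipping and stop-token padding into a single F.pad call. A toy illustration of the new order, with assumed values:

import torch
import torch.nn.functional as F

stop_text_token = 0                                   # assumed value
text_inputs = torch.randint(1, 100, (2, 40))          # batch padded out to length 40
text_lengths = torch.tensor([12, 30])
max_text_len = text_lengths.max()
text_inputs = text_inputs[:, :max_text_len]           # clip excess padding (clip_inputs=True)
text_inputs = F.pad(text_inputs, (0, 1), value=stop_text_token)  # always append a stop token
assert text_inputs.shape == (2, 31)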
@@ -427,10 +436,15 @@ class UnifiedVoice(nn.Module):
             mel_inp = mel_codes
         mel_emb = self.mel_embedding(mel_inp)
         mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
+
         if text_first:
-            text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
+            text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
+            if return_latent:
+                return mel_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
         else:
-            mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions)
+            mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
+            if return_latent:
+                return text_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
 
         if return_attentions:
             return mel_logits
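
Per the comments above, the tensors handed back under return_latent are final hidden states, not logits, and the last two sequence positions (appended to the inputs during this forward pass) are stripped before the latent is used as diffusion conditioning. A toy shape check with assumed dimensions:

import torch

mel_latents = torch.randn(1, 402, 1024)   # 400 real frames + 2 appended positions (assumed)
trimmed = mel_latents[:, :-2]
assert trimmed.shape == (1, 400, 1024)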


@@ -176,7 +176,13 @@ class DiffusionTts(nn.Module):
             AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
         )
         self.code_norm = normalization(model_channels)
-        self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
+        self.latent_conditioner = nn.Sequential(
+            nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+        )
         self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
                                                  nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
                                                  AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
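
The 1x1-convolution latent_converter is replaced by a deeper latent_conditioner, but the shape contract stays the same: (b, in_latent_channels, s) in, (b, model_channels, s) out. A runnable sketch of that contract, with nn.Identity standing in for the repo's AttentionBlock so the snippet is self-contained:

import torch
import torch.nn as nn

in_latent_channels, model_channels = 1024, 512   # assumed sizes
conditioner = nn.Sequential(
    nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
    nn.Identity(),   # stand-in for the four AttentionBlock(model_channels, ...) layers
)
latents = torch.randn(2, in_latent_channels, 37)
print(conditioner(latents).shape)   # torch.Size([2, 512, 37])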
@@ -190,6 +196,7 @@ class DiffusionTts(nn.Module):
             DiffusionLayer(model_channels, dropout, num_heads),
             DiffusionLayer(model_channels, dropout, num_heads),
         )
+
         self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
         self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
@@ -206,7 +213,7 @@ class DiffusionTts(nn.Module):
         groups = {
             'minicoder': list(self.contextual_embedder.parameters()),
             'layers': list(self.layers.parameters()),
-            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
+            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()),
             'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
             'time_embed': list(self.time_embed.parameters()),
         }
@@ -227,7 +234,7 @@ class DiffusionTts(nn.Module):
         cond_emb = conds.mean(dim=-1)
         cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
         if is_latent(aligned_conditioning):
-            code_emb = self.autoregressive_latent_converter(aligned_conditioning)
+            code_emb = self.latent_conditioner(aligned_conditioning)
         else:
             code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
             code_emb = self.code_converter(code_emb)
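
timestep_independent now dispatches on the kind of conditioning it receives: integer mel codes go through the embedding/converter path, float latents through the new latent_conditioner. Presumably is_latent distinguishes the two by dtype; a minimal sketch under that assumption:

import torch

def is_latent(t):
    return t.dtype == torch.float   # assumed: codes are long tensors, latents are float

codes = torch.randint(0, 8192, (1, 250))   # -> code_embedding + code_converter path
latents = torch.randn(1, 1024, 250)        # -> latent_conditioner path
assert not is_latent(codes) and is_latent(latents)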
@@ -269,7 +276,7 @@ class DiffusionTts(nn.Module):
         if conditioning_free:
             code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
             unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
-            unused_params.extend(list(self.latent_converter.parameters()))
+            unused_params.extend(list(self.latent_conditioner.parameters()))
         else:
             if precomputed_aligned_embeddings is not None:
                 code_emb = precomputed_aligned_embeddings
@@ -278,7 +285,7 @@ class DiffusionTts(nn.Module):
                 if is_latent(aligned_conditioning):
                     unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
                 else:
-                    unused_params.extend(list(self.latent_converter.parameters()))
+                    unused_params.extend(list(self.latent_conditioner.parameters()))
 
         unused_params.append(self.unconditioned_embedding)
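
Collecting unused_params on whichever branch is inactive is the usual trick for keeping DistributedDataParallel happy when parts of a model receive no gradient on a given step. A hedged sketch of how such a list is typically consumed (mirroring the pattern, not this file verbatim):

# Fold inactive-branch parameters into the output with zero weight so every
# parameter participates in the autograd graph and DDP does not complain.
extraneous_addition = 0
for p in unused_params:
    extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0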