support latents into the diffusion decoder

James Betker 2022-04-12 20:53:09 -06:00
parent e2ee843098
commit 3214ca0dfe
5 changed files with 55 additions and 315 deletions

api.py (21 changed lines)
View File

@@ -117,7 +117,7 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
        cond_mels.append(cond_mel)
    cond_mels = torch.stack(cond_mels, dim=1)
-    output_seq_len = mel_codes.shape[-1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
+    output_seq_len = mel_codes.shape[1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
    output_shape = (mel_codes.shape[0], 100, output_seq_len)
    precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False)
@@ -151,11 +151,6 @@ class TextToSpeech:
                                      layer_drop=0, unconditioned_percentage=0).cpu().eval()
        self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
-        self.diffusion_next = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
-                                           in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
-                                           layer_drop=0, unconditioned_percentage=0).cpu().eval()
-        self.diffusion_next.load_state_dict(torch.load('.models/diffusion_next.pth'))
-
        self.vocoder = UnivNetGenerator().cpu()
        self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
        self.vocoder.eval(inference=True)
@@ -223,12 +218,22 @@ class TextToSpeech:
            self.clip = self.clip.cpu()
            del samples
+            # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
+            # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
+            # results, but will increase memory usage.
+            self.autoregressive = self.autoregressive.cuda()
+            best_latents = self.autoregressive(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results,
+                                               torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device),
+                                               return_latent=True, clip_inputs=False)
+            self.autoregressive = self.autoregressive.cpu()
+
            print("Performing vocoding..")
            wav_candidates = []
            self.diffusion = self.diffusion.cuda()
            self.vocoder = self.vocoder.cuda()
            for b in range(best_results.shape[0]):
                codes = best_results[b].unsqueeze(0)
+                latents = best_latents[b].unsqueeze(0)
                # Find the first occurrence of the "calm" token and trim the codes to that.
                ctokens = 0
@@ -238,10 +243,10 @@
                    else:
                        ctokens = 0
                    if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
-                        codes = codes[:, :k]
+                        latents = latents[:, :k]
                        break
-                mel = do_spectrogram_diffusion(self.diffusion, diffuser, codes, voice_samples, temperature=diffusion_temperature)
+                mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, voice_samples, temperature=diffusion_temperature)
                wav = self.vocoder.inference(mel)
                wav_candidates.append(wav.cpu())
            self.diffusion = self.diffusion.cpu()
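
A note for readers following the api.py changes above: the autoregressive model is now queried a second time with `return_latent=True, clip_inputs=False` to recover its last hidden layer for the best candidates, those latents are trimmed at the same "calm token" boundary that the discrete codes used to be trimmed at, and the trimmed latents (not the codes) are what `do_spectrogram_diffusion` receives. The toy sketch below mirrors only that trimming logic; the batch size, sequence length, latent width, and the calm-token id are illustrative assumptions, not values taken from this diff.

import torch

b, t, latent_dim = 1, 250, 1024        # assumed sizes, for illustration only
calm_token = 83                        # placeholder id; the real constant lives in api.py

codes = torch.randint(0, 8192, (b, t))     # discrete mel codes from the autoregressive model
latents = torch.randn(b, t, latent_dim)    # last-hidden-layer latents (return_latent=True)
codes[0, 120:140] = calm_token             # plant a calm stretch so the trim triggers in this toy

# Mirror the trimming loop in tts(): once more than 8 consecutive calm tokens are seen,
# cut the latents (previously the codes) at that point and stop.
ctokens = 0
for k in range(codes.shape[-1]):
    ctokens = ctokens + 1 if codes[0, k] == calm_token else 0
    if ctokens > 8:
        latents = latents[:, :k]
        break

# The diffusion decoder is now conditioned on these latents rather than on the codes.
print(latents.shape)   # torch.Size([1, 128, 1024]) with the planted calm stretch above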

View File

@@ -7,7 +7,7 @@ from utils.audio import load_audio
if __name__ == '__main__':
    fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv'
-    outpath = 'D:\\tmp\\tortoise-tts-eval\\diverse_auto_256_samp_100_di_4'
+    outpath = 'D:\\tmp\\tortoise-tts-eval\\diverse_new_decoder_1'
    outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'

    os.makedirs(outpath, exist_ok=True)

View File

@@ -362,7 +362,7 @@ class UnifiedVoice(nn.Module):
            mel_input_tokens[b, actual_end:] = self.stop_mel_token
        return mel_input_tokens

-    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False):
+    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
        if second_inputs is not None:
            emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
        else:
@@ -374,6 +374,10 @@
        enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
        enc = self.final_norm(enc)
+
+        if return_latent:
+            return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
+
        first_logits = enc[:, :first_inputs.shape[1]]
        first_logits = first_head(first_logits)
        first_logits = first_logits.permute(0,2,1)
@@ -385,7 +389,8 @@
        else:
            return first_logits

-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
+                return_latent=False, clip_inputs=True):
        """
        Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
        (actuated by `text_first`).
@@ -396,19 +401,23 @@
        mel_inputs: long tensor, (b,m)
        wav_lengths: long tensor, (b,)
        raw_mels: MEL float tensor (b,80,s)
-        """
-        assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
-        assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_text_len = text_lengths.max()
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
-        max_mel_len = wav_lengths.max() // self.mel_length_compression
-        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
-        if raw_mels is not None:
-            raw_mels = raw_mels[:, :, :max_mel_len*4]
+
+        If return_attentions is specified, only logits are returned.
+        If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
+        If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
+        """
+        if clip_inputs:
+            # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+            # chopping the inputs by the maximum actual length.
+            max_text_len = text_lengths.max()
+            text_inputs = text_inputs[:, :max_text_len]
+            max_mel_len = wav_lengths.max() // self.mel_length_compression
+            mel_codes = mel_codes[:, :max_mel_len]
+            if raw_mels is not None:
+                raw_mels = raw_mels[:, :, :max_mel_len*4]
        mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
+        text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
+        mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)

        speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
        conds = []
@@ -427,10 +436,15 @@
            mel_inp = mel_codes
        mel_emb = self.mel_embedding(mel_inp)
        mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
+
        if text_first:
-            text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
+            text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
+            if return_latent:
+                return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
        else:
-            mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions)
+            mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
+            if return_latent:
+                return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.

        if return_attentions:
            return mel_logits
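
To make the new `return_latent` slicing above concrete, here is a small runnable shape check. Only the indexing mirrors the diff; `enc` stands in for `self.final_norm(gpt_out.last_hidden_state[:, 1:])`, and the conditioning, text, and mel lengths plus the model width are assumed values chosen for illustration.

import torch

b, d = 2, 1024                         # batch and model width (assumed)
cond_len, text_len, mel_len = 1, 40, 200
enc = torch.randn(b, cond_len + text_len + mel_len, d)   # stand-in for the final hidden state

# get_logits(..., return_latent=True) returns the text-aligned and mel-aligned spans:
text_latents = enc[:, cond_len:cond_len + text_len]
mel_latents = enc[:, -mel_len:]

# forward() then drops the trailing positions its own padding appended (the [:, :-2] above)
# before handing the mel-aligned latents to the diffusion decoder.
mel_latents = mel_latents[:, :-2]
print(text_latents.shape, mel_latents.shape)   # torch.Size([2, 40, 1024]) torch.Size([2, 198, 1024])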

View File

@@ -176,7 +176,13 @@ class DiffusionTts(nn.Module):
            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
        )
        self.code_norm = normalization(model_channels)
-        self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
+        self.latent_conditioner = nn.Sequential(
+            nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+        )
        self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
                                                 nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
@@ -190,6 +196,7 @@
            DiffusionLayer(model_channels, dropout, num_heads),
            DiffusionLayer(model_channels, dropout, num_heads),
        )
+
        self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
        self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
@@ -206,7 +213,7 @@
        groups = {
            'minicoder': list(self.contextual_embedder.parameters()),
            'layers': list(self.layers.parameters()),
-            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
+            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()),
            'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
            'time_embed': list(self.time_embed.parameters()),
        }
@@ -227,7 +234,7 @@
        cond_emb = conds.mean(dim=-1)
        cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
        if is_latent(aligned_conditioning):
-            code_emb = self.autoregressive_latent_converter(aligned_conditioning)
+            code_emb = self.latent_conditioner(aligned_conditioning)
        else:
            code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
            code_emb = self.code_converter(code_emb)
@@ -269,7 +276,7 @@
        if conditioning_free:
            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
            unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
-            unused_params.extend(list(self.latent_converter.parameters()))
+            unused_params.extend(list(self.latent_conditioner.parameters()))
        else:
            if precomputed_aligned_embeddings is not None:
                code_emb = precomputed_aligned_embeddings
@@ -278,7 +285,7 @@
                if is_latent(aligned_conditioning):
                    unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
                else:
-                    unused_params.extend(list(self.latent_converter.parameters()))
+                    unused_params.extend(list(self.latent_conditioner.parameters()))

            unused_params.append(self.unconditioned_embedding)
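
The constructor change above swaps the old single 1x1 `latent_converter` projection for a `latent_conditioner` that first projects the autoregressive latents with a kernel-3 convolution and then passes them through four attention blocks, so the latents are mixed along the time axis before they condition the decoder. The sketch below shows only the shape contract of such a stack; the repo's `AttentionBlock` is replaced by a plain `nn.TransformerEncoderLayer` stand-in, and the channel counts and sequence length are assumptions.

import torch
import torch.nn as nn

class AttentionStandIn(nn.Module):
    """Channel-preserving self-attention over the time axis: (B, C, T) in, (B, C, T) out."""
    def __init__(self, channels, num_heads):
        super().__init__()
        self.layer = nn.TransformerEncoderLayer(d_model=channels, nhead=num_heads, batch_first=True)

    def forward(self, x):
        return self.layer(x.permute(0, 2, 1)).permute(0, 2, 1)

in_latent_channels, model_channels, num_heads = 1024, 512, 8   # assumed sizes
latent_conditioner = nn.Sequential(
    nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
    *[AttentionStandIn(model_channels, num_heads) for _ in range(4)],
)

latents = torch.randn(2, in_latent_channels, 86)   # (B, C_latent, T) autoregressive latents
code_emb = latent_conditioner(latents)             # (B, model_channels, T); T is preserved
print(code_emb.shape)                              # torch.Size([2, 512, 86])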

View File

@@ -1,286 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2PreTrainedModel, GPT2Config
from models.xtransformers import TransformerWrapper, Encoder, Decoder
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from models.arch_util import AttentionBlock
class InferenceModel(GPT2PreTrainedModel):
    """
    Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with
    this transformer.
    """
    def __init__(self, model):
        super().__init__(GPT2Config())
        self.transformer = model
        self.context = None

    def parallelize(self, device_map=None):
        # Not implemented.
        pass

    def deparallelize(self):
        # Not implemented.
        pass

    def get_output_embeddings(self):
        assert False, "Unsupported operation."

    def set_output_embeddings(self, new_embeddings):
        assert False, "Unsupported operation."

    def store_context(self, context):
        self.context = context

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        assert self.context is not None
        assert inputs_embeds is None  # Not supported by this inference model.
        assert labels is None  # Training not supported by this inference model.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values,
                                       use_cache=use_cache, expected_seq_len=100)
        if use_cache:
            hidden_states, present_key_values = out
        else:
            hidden_states = out
            present_key_values = None
        logits = self.transformer.decoder.to_logits(hidden_states)

        if not return_dict:
            return (logits, )

        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=logits,
            past_key_values=present_key_values,
            hidden_states=hidden_states,
            attentions=None,
            cross_attentions=None,
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )
class ResBlock(nn.Module):
    """
    Basic residual convolutional block that uses GroupNorm.
    """
    def __init__(self, chan):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
            nn.GroupNorm(chan//8, chan),
            nn.ReLU(),
            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
            nn.GroupNorm(chan//8, chan)
        )

    def forward(self, x):
        return F.relu(self.net(x) + x)

class ConditioningEncoder(nn.Module):
    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 attn_blocks=6,
                 num_attn_heads=4,
                 do_checkpointing=False):
        super().__init__()
        attn = []
        self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2),
                                  nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2),
                                  ResBlock(embedding_dim//2),
                                  nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2))
        for a in range(attn_blocks):
            attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
        self.attn = nn.Sequential(*attn)
        self.dim = embedding_dim

    def forward(self, x):
        h = self.init(x)
        h = self.attn(h)
        return h.mean(dim=2)
class AutoregressiveCodegen(nn.Module):
    def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
        super().__init__()
        assert depth >= 8  # This is the minimum bound to support the context interleaving that happens later.

        self.START_TOKEN=8192
        self.STOP_TOKEN=8193
        self.START_TEXT_TOKEN = 255
        self.STOP_TEXT_TOKEN = 0
        self.max_text_token_id = num_text_tokens
        self.max_mel_token_id = num_mel_tokens
        self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
        self.encoder = TransformerWrapper(
            num_tokens=num_text_tokens,
            use_pos_emb=False,
            max_seq_len=-1,
            attn_layers = Encoder(
                depth=depth,
                heads=model_dim//64,
                dim=model_dim,
                attn_dropout=dropout,
                ff_dropout=dropout,
                use_rmsnorm=True,
                ff_glu=True,
                ff_mult=1,
                rotary_pos_emb=True,
                attn_rel_pos_bias=True,
            ))
        self.encoder.norm = nn.Identity()  # This layer and the next are unused.
        self.encoder.to_logits = nn.Identity()
        self.decoder = TransformerWrapper(
            num_tokens=num_mel_tokens,
            use_pos_emb=False,
            max_seq_len=-1,
            attn_layers=Decoder(
                depth=depth,
                heads=model_dim//64,
                dim=model_dim,
                attn_dropout=dropout,
                ff_dropout=dropout,
                use_rmsnorm=True,
                ff_glu=True,
                ff_mult=1,
                rotary_pos_emb=True,
                cross_attend=True,
                attn_rel_pos_bias=True,
            ))

    def get_grad_norm_parameter_groups(self):
        return {
            'encoder': list(self.encoder.parameters()),
            'decoder': list(self.decoder.parameters()),
            'minicoder': list(self.mel_embedding.parameters()),
        }
    def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
        assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
        assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'

        # Format mel_codes with a stop token on the end.
        mel_lengths = wav_lengths // 1024 + 1
        for b in range(mel_codes.shape[0]):
            mel_codes[b, mel_lengths[b]:] = self.STOP_TOKEN
        mel_codes = F.pad(mel_codes, (0, 1), value=self.STOP_TOKEN)

        # Build the context
        if len(conditioning_signal.shape) != 4:
            conditioning_signal = conditioning_signal.unsqueeze(1)
        cond_embs = []
        for i in range(conditioning_signal.shape[1]):
            cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
        cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
        # Since all positional embeddings are relative, it is (probably) important to "fix" the text with some permanent embeddings.
        text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
        text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
        _, enc_text = self.encoder(text_codes, return_hiddens=True)
        # Interleave cond_emb into the first few contexts.
        full_context = enc_text
        full_context[1] = cond_emb
        full_context[3] = cond_emb
        full_context[6] = cond_emb

        # Execute the decoder
        dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
        dec = self.decoder(dec_inputs, full_context=full_context)
        if not return_loss:
            return dec
        loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
        return loss_mel
    def generate(self, conditioning_signal, text_codes, max_tokens=256, **hf_generate_kwargs):
        inference_model = InferenceModel(self)

        # Build the context
        if len(conditioning_signal.shape) != 4:
            conditioning_signal = conditioning_signal.unsqueeze(1)
        cond_embs = []
        for i in range(conditioning_signal.shape[1]):
            cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
        cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
        text_codes = F.pad(text_codes, (1,0), value=self.START_TEXT_TOKEN)
        text_codes = F.pad(text_codes, (0,1), value=self.STOP_TEXT_TOKEN)
        _, enc_text = self.encoder(text_codes, return_hiddens=True)
        # Interleave cond_emb into the first few contexts.
        full_context = enc_text
        full_context[1] = cond_emb
        full_context[3] = cond_emb
        full_context[6] = cond_emb
        inference_model.store_context(full_context)

        gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
                                       max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, use_cache=False,
                                       **hf_generate_kwargs)
        return gen.sequences

if __name__ == '__main__':
    codegen = AutoregressiveCodegen(256, 10)
    torch.save(codegen.state_dict(), 'sample.pth')
    #codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
    codegen(torch.randint(0,256, (2,200)),
            torch.randn(2,80,120),
            torch.randint(0,8192, (2,350)),
            torch.tensor([192,350]))