diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py index d53e595a..134147cd 100644 --- a/codes/models/audio/tts/unified_voice2.py +++ b/codes/models/audio/tts/unified_voice2.py @@ -32,17 +32,17 @@ class ResBlock(nn.Module): class GPT2InferenceModel(GPT2PreTrainedModel): - def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear): + def __init__(self, config, gpt, posterior_pos_emb, embeddings, norm, linear): super().__init__(config) self.transformer = gpt - self.text_pos_embedding = text_pos_emb + self.posterior_pos_embedding = posterior_pos_emb self.embeddings = embeddings - self.lm_head = nn.Sequential(norm, linear) + self.head = nn.Sequential(norm, linear) # Model parallel self.model_parallel = False self.device_map = None - self.cached_mel_emb = None + self.cached_prior_emb = None def parallelize(self, device_map=None): self.device_map = ( @@ -52,27 +52,26 @@ class GPT2InferenceModel(GPT2PreTrainedModel): ) assert_device_map(self.device_map, len(self.transformer.h)) self.transformer.parallelize(self.device_map) - self.lm_head = self.lm_head.to(self.transformer.first_device) + self.head = self.head.to(self.transformer.first_device) self.model_parallel = True def deparallelize(self): self.transformer.deparallelize() self.transformer = self.transformer.to("cpu") - self.lm_head = self.lm_head.to("cpu") + self.head = self.head.to("cpu") self.model_parallel = False torch.cuda.empty_cache() def get_output_embeddings(self): - return self.lm_head + return self.head def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings + self.head = new_embeddings - def store_mel_emb(self, mel_emb): - self.cached_mel_emb = mel_emb + def store_prior_emb(self, mel_emb): + self.cached_prior_emb = mel_emb def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past: @@ -117,25 +116,25 @@ class GPT2InferenceModel(GPT2PreTrainedModel): output_hidden_states=None, return_dict=None, ): - assert self.cached_mel_emb is not None + assert self.cached_prior_emb is not None assert inputs_embeds is None # Not supported by this inference model. assert labels is None # Training not supported by this inference model. return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Create embedding - mel_len = self.cached_mel_emb.shape[1] + prior_len = self.cached_prior_emb.shape[1] if input_ids.shape[1] != 1: - text_inputs = input_ids[:, mel_len:] - text_emb = self.embeddings(text_inputs) - text_emb = text_emb + self.text_pos_embedding(text_emb) - if self.cached_mel_emb.shape[0] != text_emb.shape[0]: - mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0) + posterior_inputs = input_ids[:, prior_len:] + posterior_emb = self.embeddings(posterior_inputs) + posterior_emb = posterior_emb + self.posterior_pos_embedding(posterior_emb) + if self.cached_prior_emb.shape[0] != posterior_emb.shape[0]: + prior_emb = self.cached_prior_emb.repeat_interleave(posterior_emb.shape[0] // self.cached_prior_emb.shape[0], 0) else: - mel_emb = self.cached_mel_emb - emb = torch.cat([mel_emb, text_emb], dim=1) + prior_emb = self.cached_prior_emb + emb = torch.cat([prior_emb, posterior_emb], dim=1) else: emb = self.embeddings(input_ids) - emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device) + emb = emb + self.posterior_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - prior_len, attention_mask.device) transformer_outputs = self.transformer( inputs_embeds=emb, @@ -156,16 +155,16 @@ class GPT2InferenceModel(GPT2PreTrainedModel): # Set device for model parallelism if self.model_parallel: torch.cuda.set_device(self.transformer.first_device) - hidden_states = hidden_states.to(self.lm_head.weight.device) + hidden_states = hidden_states.to(self.head.weight.device) - lm_logits = self.lm_head(hidden_states) + logits = self.head(hidden_states) if not return_dict: - return (lm_logits,) + transformer_outputs[1:] + return (logits,) + transformer_outputs[1:] return CausalLMOutputWithCrossAttentions( loss=None, - logits=lm_logits, + logits=logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, @@ -239,9 +238,7 @@ class MelEncoder(nn.Module): class UnifiedVoice(nn.Module): def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, mel_length_compression=1024, number_text_tokens=256, number_mel_codes=8194, start_mel_token=8192, - stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True, start_text_token=None, - checkpointing=True, average_conditioning_embeddings=False, freeze_everything_but_position_embeddings=False, - types=1): + stop_mel_token=8193, start_text_token=None, checkpointing=True, types=1): """ Args: layers: Number of layers in transformer stack. @@ -252,14 +249,10 @@ class UnifiedVoice(nn.Module): max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. number_text_tokens: - stop_text_token: number_mel_codes: start_mel_token: stop_mel_token: - train_solo_embeddings: - use_mel_codes_as_input: checkpointing: - average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model. """ super().__init__() @@ -277,41 +270,20 @@ class UnifiedVoice(nn.Module): self.model_dim = model_dim self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) - self.average_conditioning_embeddings = average_conditioning_embeddings self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim) - if use_mel_codes_as_input: - self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) - else: - self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) + self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens, self.max_text_tokens, checkpointing) - if train_solo_embeddings: - self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) - self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) - else: - self.mel_solo_embedding = 0 - self.text_solo_embedding = 0 self.final_norm = nn.LayerNorm(model_dim) self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1) self.mel_head = nn.Linear(model_dim, self.number_mel_codes) # Initialize the embeddings per the GPT-2 scheme - embeddings = [self.text_embedding] - if use_mel_codes_as_input: - embeddings.append(self.mel_embedding) + embeddings = [self.text_embedding, self.mel_embedding] for module in embeddings: module.weight.data.normal_(mean=0.0, std=.02) - if freeze_everything_but_position_embeddings: - for p in self.parameters(): - p.requires_grad = False - p.DO_NOT_TRAIN = True - for m in [self.mel_pos_embedding, self.text_pos_embedding]: - for p in m.parameters(): - del p.DO_NOT_TRAIN - p.requires_grad = True - def get_grad_norm_parameter_groups(self): return { 'conditioning_encoder': list(self.conditioning_encoder.parameters()), @@ -338,15 +310,13 @@ class UnifiedVoice(nn.Module): mel_input_tokens[b, actual_end:] = self.stop_mel_token return mel_input_tokens - def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False): + def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, return_latent=False): if second_inputs is not None: emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) else: emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) - gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) - if get_attns: - return gpt_out.attentions + gpt_out = self.gpt(inputs_embeds=emb, return_dict=True) enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input enc = self.final_norm(enc) @@ -372,13 +342,11 @@ class UnifiedVoice(nn.Module): for j in range(speech_conditioning_input.shape[1]): conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) - if self.average_conditioning_embeddings: - conds = conds.mean(dim=1).unsqueeze(1) + conds = conds.mean(dim=1).unsqueeze(1) return conds - def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False, - return_latent=False, clip_inputs=True): + def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, return_latent=False): """ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode (actuated by `text_first`). @@ -388,25 +356,13 @@ class UnifiedVoice(nn.Module): text_lengths: long tensor, (b,) mel_inputs: long tensor, (b,m) wav_lengths: long tensor, (b,) - raw_mels: MEL float tensor (b,80,s) - If return_attentions is specified, only logits are returned. If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. - If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality. """ # Types are expressed by expanding the text embedding space. if types is not None: text_inputs = text_inputs * (1+types).unsqueeze(-1) - if clip_inputs: - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. - max_text_len = text_lengths.max() - text_inputs = text_inputs[:, :max_text_len] - max_mel_len = wav_lengths.max() // self.mel_length_compression - mel_codes = mel_codes[:, :max_mel_len] - if raw_mels is not None: - raw_mels = raw_mels[:, :, :max_mel_len*4] mel_codes = self.set_mel_padding(mel_codes, wav_lengths) text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) @@ -416,86 +372,24 @@ class UnifiedVoice(nn.Module): text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) - if raw_mels is not None: - mel_inp = F.pad(raw_mels, (0, 8)) - else: - mel_inp = mel_codes + mel_inp = mel_codes mel_emb = self.mel_embedding(mel_inp) mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) if text_first: - text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) + text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, return_latent=return_latent) if return_latent: return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. else: - mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent) + mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, return_latent=return_latent) if return_latent: return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. - if return_attentions: - return mel_logits loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_text.mean(), loss_mel.mean(), mel_logits - def text_forward(self, speech_conditioning_input, text_inputs, text_lengths): - """ - Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the - model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided). - """ - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. - max_text_len = text_lengths.max() - text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token) - - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input - conds = [] - for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) - conds = torch.stack(conds, dim=1) - if self.average_conditioning_embeddings: - conds = conds.mean(dim=1).unsqueeze(1) - - text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding - text_logits = self.get_logits(conds, text_emb, self.text_head) - loss_text = F.cross_entropy(text_logits, text_targets.long()) - return loss_text.mean() - - def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None): - """ - Performs autoregressive modeling on only speech data. - """ - assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}' - - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. - max_mel_len = wav_lengths.max() // self.mel_length_compression - mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token) - mel_codes = self.set_mel_padding(mel_codes, wav_lengths) - if raw_mels is not None: - raw_mels = raw_mels[:, :, :max_mel_len*4] - - speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input - conds = [] - for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) - conds = torch.stack(conds, dim=1) - if self.average_conditioning_embeddings: - conds = conds.mean(dim=1).unsqueeze(1) - - mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) - if raw_mels is not None: - mel_inp = F.pad(raw_mels, (0, 4)) - else: - mel_inp = mel_codes - mel_emb = self.mel_embedding(mel_inp) - mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding - mel_logits = self.get_logits(conds, mel_emb, self.mel_head) - loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) - return loss_mel.mean() - - def inference_speech(self, speech_conditioning_input, text_inputs, return_attentions=False, **hf_generate_kwargs): + def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs): if self.max_mel_tokens == -1: # Assume if this is the case, max_mel_tokens=-1 also seq_length = 2002 # Arbitrary default. else: @@ -522,64 +416,17 @@ class UnifiedVoice(nn.Module): for j in range(speech_conditioning_input.shape[1]): conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) - if self.average_conditioning_embeddings: - conds = conds.mean(dim=1).unsqueeze(1) + conds = conds.mean(dim=1).unsqueeze(1) emb = torch.cat([conds, text_emb], dim=1) - self.inference_model.store_mel_emb(emb) + self.inference_model.store_prior_emb(emb) fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device) fake_inputs[:,-1] = self.start_mel_token gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, - max_length=seq_length, output_attentions=return_attentions, return_dict_in_generate=True, **hf_generate_kwargs) - if return_attentions: - return gen.sequences[:, fake_inputs.shape[1]:], gen.attentions - else: - return gen.sequences[:, fake_inputs.shape[1]:] - - - # Turns the (utterly insane) output of HF.generate() into a far more sane output: - # [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence - def make_hf_generate_attentions_sane(self, attentions): - layers = [[] for _ in range(len(attentions[0]))] - full_attention_size = attentions[-1][0].shape[-1] - for i, gen in enumerate(attentions): - for j, lyr in enumerate(gen): - layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1]))) - catted = [] - for lyr in layers: - catted.append(torch.cat(lyr, dim=2)) - return catted - - def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds): - """ - This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice. - """ - text_padding = num_conds+2 - num_text = text.shape[-1] - num_context = num_text + text_padding - assert num_context + 1 == attentions[0][0].shape[-1] - attentions = self.make_hf_generate_attentions_sane(attentions) - results = [torch.empty_like(codes) for _ in range(len(attentions))] - for l, layer in enumerate(attentions): - dec_context = layer[:, :, num_context:, :] - # Mask out everything that isn't text (including the start token, which gets a LOT of attention) - dec_context[:,:,:,:text_padding+1] = 0 - dec_context[:,:,:,num_context:] = 0 - for h in range(dec_context.shape[1]): - dec_context_indices = torch.argmax(dec_context[0,h], dim=-1) - print(f'layer_{l};head_{h}: ' + str(dec_context_indices)) - for t, att_tok in enumerate(attentions): - combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device) - for lyr in att_tok: - token_to_text_attentions = lyr[:, :, -1, text_padding:(text_padding + num_text)].sum(dim=1) - combined_attention_weights = combined_attention_weights + token_to_text_attentions - break - most_attended_text_token = combined_attention_weights.argmax(dim=-1) - results[:, t] = most_attended_text_token - eos_token_mask = (codes != self.stop_mel_token) - return results * eos_token_mask + max_length=seq_length, return_dict_in_generate=True, **hf_generate_kwargs) + return gen.sequences[:, fake_inputs.shape[1]:] @register_model @@ -588,11 +435,10 @@ def register_unified_voice2(opt_net, opt): if __name__ == '__main__': - gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4, freeze_everything_but_position_embeddings=True, types=2) + gpt = UnifiedVoice(model_dim=256, heads=4, max_conditioning_inputs=4, types=2) l = gpt(torch.randn(2, 3, 80, 800), torch.randint(high=256, size=(2,120)), torch.tensor([32, 120]), torch.randint(high=8192, size=(2,250)), torch.tensor([250*256,195*256]), types=torch.tensor([0, 1])) - #gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))