From 596a62fe01e077a454023672ce79c94f91d42c92 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 4 Nov 2021 10:09:24 -0600 Subject: [PATCH] Apply fix to gpt_asr_hf and prep it for inference Fix is that we were predicting two characters in advance, not next character --- codes/models/gpt_voice/gpt_asr_hf.py | 188 +++++++++++++++++++++- codes/scripts/audio/asr_eval.py | 8 +- codes/train.py | 2 +- codes/trainer/injectors/base_injectors.py | 5 +- 4 files changed, 192 insertions(+), 11 deletions(-) diff --git a/codes/models/gpt_voice/gpt_asr_hf.py b/codes/models/gpt_voice/gpt_asr_hf.py index 41121d09..1ba6217b 100644 --- a/codes/models/gpt_voice/gpt_asr_hf.py +++ b/codes/models/gpt_voice/gpt_asr_hf.py @@ -1,7 +1,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import GPT2Model, GPT2Config +from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions +from transformers.utils.model_parallel_utils import get_device_map, assert_device_map from models.tacotron2.text import symbols from trainer.networks import register_model @@ -46,6 +48,159 @@ class MelEncoder(nn.Module): return self.encoder(x) +class GPT2InferenceModel(GPT2PreTrainedModel): + def __init__(self, config, gpt, text_pos_emb, norm, linear): + super().__init__(config) + self.transformer = gpt + self.text_pos_embedding = text_pos_emb + self.lm_head = nn.Sequential(norm, linear) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.cached_mel_emb = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def store_mel_emb(self, mel_emb): + self.cached_mel_emb = mel_emb + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_mel_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Create embedding + mel_len = self.cached_mel_emb.shape[1] + if input_ids.shape[1] != 1: + text_inputs = input_ids[:, mel_len:] + text_emb = self.transformer.get_input_embeddings()(text_inputs) + text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_emb.device)) + if self.cached_mel_emb.shape[0] != text_emb.shape[0]: + mel_emb = self.cached_mel_emb.repeat(text_emb.shape[0], 1, 1) + else: + mel_emb = self.cached_mel_emb + emb = torch.cat([mel_emb, text_emb], dim=1) + else: + emb = self.transformer.get_input_embeddings()(input_ids) + \ + self.text_pos_embedding(torch.tensor(attention_mask.shape[1]-mel_len, device=attention_mask.device)).unsqueeze(0).unsqueeze(0) + + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + class GptAsrHf(nn.Module): NUMBER_SYMBOLS = len(symbols) NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1 @@ -61,17 +216,19 @@ class GptAsrHf(nn.Module): self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim) self.mel_pos_embedding = nn.Embedding(self.max_mel_frames, model_dim) seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames - self.gpt = GPT2Model(GPT2Config(vocab_size=self.NUMBER_TEXT_TOKENS, + self.gpt_config = GPT2Config(vocab_size=self.NUMBER_TEXT_TOKENS, n_positions=seq_length, n_ctx=seq_length, n_embd=model_dim, n_layer=layers, n_head=heads, gradient_checkpointing=checkpointing, - use_cache=not checkpointing)) + use_cache=not checkpointing) + self.gpt = GPT2Model(self.gpt_config) self.final_norm = nn.LayerNorm(model_dim) self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) + def get_logits(self, mel_inputs, text_targets): # Pad front and back. Pad at front is the "START" token. text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS) @@ -91,9 +248,32 @@ class GptAsrHf(nn.Module): def forward(self, mel_inputs, text_targets): text_logits = self.get_logits(mel_inputs, text_targets) - loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long()) + loss_text = F.cross_entropy(text_logits, text_targets.long()) return loss_text.mean(), text_logits + def inference(self, mel_inputs, cond_text=None, do_sample=False, temperature=1.0, num_beams=8): + if not hasattr(self, 'inference_model'): + self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head) + + mel_emb = self.mel_encoder(mel_inputs) + assert mel_emb.shape[-1] <= self.max_mel_frames + mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1])) + mel_emb = mel_emb.permute(0,2,1).contiguous() + mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + self.inference_model.store_mel_emb(mel_emb) + + # "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above. + if cond_text is None: + fake_inputs = torch.full((1,self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device) + fake_inputs[:,-1] = self.NUMBER_SYMBOLS + else: + cond_used = 10 + fake_inputs = torch.full((1,self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device) + fake_inputs[:,-1-cond_used] = self.NUMBER_SYMBOLS + fake_inputs[:, -cond_used:] = cond_text[:, :cond_used] + gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0, + max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=False) + return gen[:, self.max_mel_frames:] @register_model def register_gpt_asr_hf(opt_net, opt): diff --git a/codes/scripts/audio/asr_eval.py b/codes/scripts/audio/asr_eval.py index bcb6b917..e459705b 100644 --- a/codes/scripts/audio/asr_eval.py +++ b/codes/scripts/audio/asr_eval.py @@ -41,7 +41,7 @@ if __name__ == "__main__": torch.backends.cudnn.benchmark = True want_metrics = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_mass.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_hf.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) utils.util.loaded_options = opt @@ -71,11 +71,11 @@ if __name__ == "__main__": tq = tqdm(test_loader) for data in tq: - if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255: - continue + #if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255: + # continue pred = forward_pass(model, data, dataset_dir, opt, batch) pred = pred.replace('_', '') - output.write(f'{pred}\t{os.path.basename(data["path"][0])}\n') + output.write(f'{pred}\t{os.path.basename(data["filenames"][0])}\n') print(pred) output.flush() batch += 1 diff --git a/codes/train.py b/codes/train.py index 5efd01ec..e9a203e7 100644 --- a/codes/train.py +++ b/codes/train.py @@ -284,7 +284,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py index 0d275845..1e11bf72 100644 --- a/codes/trainer/injectors/base_injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -40,6 +40,7 @@ class GeneratorInjector(Injector): super(GeneratorInjector, self).__init__(opt, env) self.grad = opt['grad'] if 'grad' in opt.keys() else True self.method = opt_get(opt, ['method'], None) # If specified, this method is called instead of __call__() + self.args = opt_get(opt, ['args'], {}) def forward(self, state): gen = self.env['generators'][self.opt['generator']] @@ -54,10 +55,10 @@ class GeneratorInjector(Injector): else: params = [state[self.input]] if self.grad: - results = method(*params) + results = method(*params, **self.args) else: with torch.no_grad(): - results = method(*params) + results = method(*params, **self.args) new_state = {} if isinstance(self.output, list): # Only dereference tuples or lists, not tensors. IF YOU REACH THIS ERROR, REMOVE THE BRACES AROUND YOUR OUTPUTS IN THE YAML CONFIG