From 04454ee63ad685b8267fb47c0150c38db30b6350 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 2 Dec 2021 21:04:36 -0700 Subject: [PATCH] Add evaluation logic for gpt_asr_hf2 --- codes/models/gpt_voice/gpt_asr_hf2.py | 17 +++++++++-------- codes/scripts/audio/asr_eval.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py index 993f9a82..abfd086e 100644 --- a/codes/models/gpt_voice/gpt_asr_hf2.py +++ b/codes/models/gpt_voice/gpt_asr_hf2.py @@ -45,7 +45,9 @@ class MelEncoder(nn.Module): ) def forward(self, x): - return self.encoder(x) + for e in self.encoder: + x = e(x) + return x class GPT2InferenceModel(GPT2PreTrainedModel): @@ -262,22 +264,21 @@ class GptAsrHf2(nn.Module): mel_emb = self.mel_encoder(mel_inputs) assert mel_emb.shape[-1] <= self.max_mel_frames - mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1])) mel_emb = mel_emb.permute(0,2,1).contiguous() mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) self.inference_model.store_mel_emb(mel_emb) # "fake_inputs" are stand-ins for the MEL frames, which will be injected with the prep_inputs function above. if cond_text is None: - fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device) - fake_inputs[:,-1] = self.NUMBER_SYMBOLS + fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=mel_inputs.device) + fake_inputs[:,-1] = self.START_TOKEN else: cond_used = 10 - fake_inputs = torch.full((mel_inputs.shape[0],self.max_mel_frames+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device) - fake_inputs[:,-1-cond_used] = self.NUMBER_SYMBOLS + fake_inputs = torch.full((mel_emb.shape[0],mel_emb.shape[1]+1+cond_used,), fill_value=1, dtype=torch.long, device=mel_inputs.device) + fake_inputs[:,-1-cond_used] = self.START_TOKEN fake_inputs[:, -cond_used:] = cond_text[:, :cond_used] - gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0, - max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True) + gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_TOKEN, pad_token_id=0, eos_token_id=0, + max_length=self.max_symbols_per_phrase+mel_emb.shape[1], temperature=temperature, num_beams=num_beams, use_cache=True) return gen[:, self.max_mel_frames:] diff --git a/codes/scripts/audio/asr_eval.py b/codes/scripts/audio/asr_eval.py index 96605a17..7aa9976c 100644 --- a/codes/scripts/audio/asr_eval.py +++ b/codes/scripts/audio/asr_eval.py @@ -44,7 +44,7 @@ if __name__ == "__main__": torch.backends.cudnn.benchmark = True want_metrics = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_hf.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_hf2.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) utils.util.loaded_options = opt