From ba7f54c162e34b8f95b5f81794d57b991f6c23af Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 21 Feb 2022 19:13:03 -0700
Subject: [PATCH] w2v: new inference function

---
 codes/models/asr/w2v_wrapper.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/codes/models/asr/w2v_wrapper.py b/codes/models/asr/w2v_wrapper.py
index 25fa2975..3dd9085b 100644
--- a/codes/models/asr/w2v_wrapper.py
+++ b/codes/models/asr/w2v_wrapper.py
@@ -129,6 +129,22 @@ class Wav2VecWrapper(nn.Module):
         pred = logits.argmax(dim=-1)
         return [self.decode_ctc(p) for p in pred]
 
+    def inference_logits(self, audio):
+        """Return the raw logits of the wrapped wav2vec model for `audio`.
+
+        `audio` is standardized with the mean and variance taken over the
+        whole tensor (all elements jointly), dimension 1 is squeezed out,
+        and the result is fed to `self.w2v` as `input_values`. No argmax or
+        CTC decoding is applied to the output.
+
+        NOTE(review): the `squeeze(1)` suggests `audio` is shaped
+        (batch, 1, samples) - confirm against callers.
+        """
+        # 1e-7 keeps the division finite when the variance is zero.
+        audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
+        logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
+        return logits
+
 
 @register_model
 def register_wav2vec_feature_extractor(opt_net, opt):