forked from mrq/DL-Art-School
w2v_wrapper: get rid of ctc attention mask
This commit is contained in:
parent
79e8f36d30
commit
e1d71e1bd5
|
@ -20,8 +20,10 @@ class Wav2VecWrapper(nn.Module):
|
|||
"""
|
||||
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
||||
"""
|
||||
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True):
|
||||
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False):
|
||||
super().__init__()
|
||||
self.provide_attention_mask = provide_attention_mask
|
||||
|
||||
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
|
||||
# Perform some surgery to get the model we actually want.
|
||||
self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
|
||||
|
@ -55,7 +57,10 @@ class Wav2VecWrapper(nn.Module):
|
|||
unaligned_tokens[b, text_lengths[b]:] = -100
|
||||
|
||||
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
||||
if self.provide_attention_mask:
|
||||
outputs = self.w2v(input_values=audio_norm.squeeze(1), attention_mask=attention_mask, labels=unaligned_tokens)
|
||||
else:
|
||||
outputs = self.w2v(input_values=audio_norm.squeeze(1), labels=unaligned_tokens)
|
||||
|
||||
if self.output_wer:
|
||||
self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
|
||||
|
|
Loading…
Reference in New Issue
Block a user