w2v fine tuner
This commit is contained in:
parent
0c3cc5ebad
commit
29534180b2
0
codes/models/asr/__init__.py
Normal file
0
codes/models/asr/__init__.py
Normal file
108
codes/models/asr/w2v_wrapper.py
Normal file
108
codes/models/asr/w2v_wrapper.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
from itertools import groupby
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
|
||||
|
||||
from data.audio.unsupervised_audio_dataset import load_audio
|
||||
from models.tacotron2.text import symbols, sequence_to_text
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
class Wav2VecWrapper(nn.Module):
|
||||
"""
|
||||
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
||||
"""
|
||||
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True):
|
||||
super().__init__()
|
||||
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
|
||||
# Perform some surgery to get the model we actually want.
|
||||
self.w2v.wav2vec2.encoder.gradient_checkpointing = True
|
||||
self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
|
||||
self.w2v.config.vocab_size = vocab_size
|
||||
self.w2v.config.pad_token_id = 0
|
||||
self.w2v.config.ctc_loss_reduction = 'mean'
|
||||
self.w2v.config.apply_spec_augment = True
|
||||
|
||||
# We always freeze the feature extractor, which needs some special operations in DLAS
|
||||
for p in self.w2v.wav2vec2.feature_extractor.parameters():
|
||||
p.requires_grad = False
|
||||
p.DO_NOT_TRAIN = True
|
||||
if freeze_transformer:
|
||||
# Also freeze the encoder here.
|
||||
for p in list(self.w2v.wav2vec2.encoder.parameters()) + list(self.w2v.wav2vec2.feature_projection.parameters()):
|
||||
p.requires_grad = False
|
||||
p.DO_NOT_TRAIN = True
|
||||
|
||||
self.output_wer = output_wer
|
||||
if output_wer:
|
||||
self.last_pred = []
|
||||
self.last_labels = []
|
||||
|
||||
def forward(self, audio, unaligned_tokens, wav_lengths, text_lengths):
|
||||
audio = audio[:, :, :wav_lengths.max()]
|
||||
unaligned_tokens = unaligned_tokens[:, :text_lengths.max()]
|
||||
attention_mask = torch.ones_like(audio).squeeze(1)
|
||||
for b in range(audio.shape[0]):
|
||||
attention_mask[b, wav_lengths[b]:] = 0
|
||||
unaligned_tokens[b, text_lengths[b]:] = -100
|
||||
|
||||
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
||||
outputs = self.w2v(input_values=audio_norm.squeeze(1), attention_mask=attention_mask, labels=unaligned_tokens)
|
||||
|
||||
if self.output_wer:
|
||||
self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
|
||||
if len(self.last_pred) > 10:
|
||||
self.last_pred = self.last_pred[1:]
|
||||
self.last_labels.append(unaligned_tokens)
|
||||
if len(self.last_labels) > 10:
|
||||
self.last_labels = self.last_labels[1:]
|
||||
return outputs.loss
|
||||
|
||||
def decode_ctc(self, output):
|
||||
if isinstance(output, torch.Tensor):
|
||||
output = output.tolist()
|
||||
tokens = [token_group[0] for token_group in groupby(output)]
|
||||
filtered_tokens = list(filter(lambda token: token != 0, tokens))
|
||||
return filtered_tokens
|
||||
|
||||
def get_debug_values(self, step, net_name):
|
||||
res = {}
|
||||
if self.output_wer and step % 100 == 0:
|
||||
from datasets import load_metric
|
||||
wer_metric = load_metric("wer")
|
||||
label_strings = []
|
||||
pred_strings = []
|
||||
for last_labels, last_pred in zip(self.last_labels, self.last_pred):
|
||||
last_labels[last_labels == -100] = 0
|
||||
label_strings.extend([sequence_to_text(lbl) for lbl in last_labels])
|
||||
pred_strings.extend([sequence_to_text(self.decode_ctc(pred)) for pred in last_pred])
|
||||
wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
|
||||
res['wer'] = wer
|
||||
print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
|
||||
return res
|
||||
|
||||
def inference(self, audio):
|
||||
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
||||
logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
|
||||
pred = logits.argmax(dim=-1)
|
||||
return self.decode_ctc(pred)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_wav2vec2_finetune(opt_net, opt):
|
||||
return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True)
|
||||
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]))
|
||||
w2v.get_debug_values(0,"")
|
||||
|
||||
sd = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')
|
||||
w2v.load_state_dict(sd)
|
||||
pred = w2v.inference(load_audio('Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', 16000).unsqueeze(0))
|
||||
res = sequence_to_text(pred[0])
|
||||
print(res)
|
|
@ -316,7 +316,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_mass.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user