DL-Art-School/codes/models/asr/w2v_wrapper.py

156 lines
6.7 KiB
Python
Raw Normal View History

2022-02-13 03:00:59 +00:00
from itertools import groupby
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
from data.audio.unsupervised_audio_dataset import load_audio
from models.tacotron2.text import symbols, sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get
2022-02-14 03:47:29 +00:00
def only_letters(string):
    """
    Uppercase ``string`` and keep only the characters A-Z, space and apostrophe.

    Used to normalize predicted and reference transcripts before word error rate is computed.
    """
    permitted = frozenset(' ABCDEFGHIJKLMNOPQRSTUVWXYZ\'')
    return ''.join(ch for ch in string.upper() if ch in permitted)
class Wav2VecFeatureExtractor(nn.Module):
    """
    Frozen wrapper around only the convolutional feature-extraction stage of a pretrained Wav2Vec2 model.

    Keeping this stage in its own module lets it be built and operated separately from the rest of the
    network (e.g. through DDP).
    """
    def __init__(self, basis_model='facebook/wav2vec2-large'):
        super().__init__()
        # Load the full CTC model, but retain only its feature extractor.
        pretrained = Wav2Vec2ForCTC.from_pretrained(basis_model)
        self.extractor = pretrained.wav2vec2.feature_extractor
        # Freeze every weight. DO_NOT_TRAIN is a flag presumably consumed by the DLAS trainer — TODO confirm.
        for param in self.extractor.parameters():
            param.requires_grad = False
            param.DO_NOT_TRAIN = True

    def forward(self, audio, wav_lengths):
        """
        Normalize a padded audio batch and run it through the frozen feature extractor.

        :param audio: padded waveforms; assumed (batch, 1, samples) — TODO confirm against callers.
        :param wav_lengths: per-sample lengths; only their maximum is used, to trim shared trailing padding.
        :return: the extractor's output for the normalized, channel-squeezed audio (computed without grad).
        """
        with torch.no_grad():
            trimmed = audio[:, :, :wav_lengths.max()]
            # Mean/variance are taken over the whole batch (padding included), as Wav2VecWrapper.forward does.
            std = torch.sqrt(trimmed.var() + 1e-7)
            normalized = (trimmed - trimmed.mean()) / std
            return self.extractor(normalized.squeeze(1))
2022-02-13 03:00:59 +00:00
class Wav2VecWrapper(nn.Module):
    """
    Basic wrapper class that makes a Wav2Vec2ForCTC model usable by DLAS.

    The pretrained LM head is replaced by a freshly-initialized linear layer over ``vocab_size`` symbols.
    Token 0 is configured as the pad token, and ``decode_ctc`` treats it as the CTC blank. ``forward`` returns
    the model's CTC loss. When ``output_wer`` is set, the most recent predictions and labels are retained so
    that ``get_debug_values`` can report a word error rate.
    """
    # Number of most recent batches retained for WER reporting.
    _WER_HISTORY = 10

    def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
                 checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
                 remove_feature_extractor=False):
        super().__init__()
        self.provide_attention_mask = provide_attention_mask

        self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
        # Perform some surgery to get the model we actually want.
        self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
        # Swap the pretrained vocabulary head for one sized to our symbol set.
        self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
        self.w2v.config.vocab_size = vocab_size
        self.w2v.config.pad_token_id = 0
        self.w2v.config.ctc_loss_reduction = 'sum'
        self.w2v.config.apply_spec_augment = spec_augment

        self.remove_feature_extractor = remove_feature_extractor
        if remove_feature_extractor:
            # The values passed in to the w2v model in this case are the outputs of the feature extractor.
            self.w2v.wav2vec2.feature_extractor = nn.Identity()
        else:
            # We always freeze the feature extractor, which needs some special operations in DLAS.
            self._freeze(self.w2v.wav2vec2.feature_extractor.parameters())

        if freeze_transformer:
            # Also freeze the encoder and the feature projection that feeds it.
            self._freeze(list(self.w2v.wav2vec2.encoder.parameters()) +
                         list(self.w2v.wav2vec2.feature_projection.parameters()))

        self.output_wer = output_wer
        if output_wer:
            self.last_pred = []
            self.last_labels = []

    @staticmethod
    def _freeze(params):
        """Disable gradients on ``params`` and tag them with the DO_NOT_TRAIN flag used in DLAS."""
        for p in params:
            p.requires_grad = False
            p.DO_NOT_TRAIN = True

    @staticmethod
    def _padding_mask(lengths, max_len, device):
        """Return a (batch, max_len) bool tensor that is True at every position >= that row's length."""
        positions = torch.arange(max_len, device=device)
        return positions.unsqueeze(0) >= lengths.to(device).unsqueeze(1)

    def forward(self, audio, unaligned_tokens, wav_lengths, text_lengths, fea_extractor=None):
        """
        Compute the CTC loss for a padded batch.

        :param audio: padded waveforms; assumed (batch, 1, samples) — TODO confirm against callers.
        :param unaligned_tokens: padded (batch, tokens) target symbol ids. This tensor is NOT modified.
        :param wav_lengths: per-sample audio lengths (1-D tensor).
        :param text_lengths: per-sample token counts (1-D tensor). Positions past them are set to -100 so the
            CTC loss ignores them.
        :param fea_extractor: precomputed features, fed to the model instead of ``audio`` when the wrapper was
            built with ``remove_feature_extractor=True``.
        :return: the loss reported by the underlying Wav2Vec2ForCTC model.
        """
        # Trim the padding shared by the whole batch. masked_fill returns a new tensor, so the caller's tokens
        # (which other consumers of the batch may still use) are left untouched.
        unaligned_tokens = unaligned_tokens[:, :text_lengths.max()]
        unaligned_tokens = unaligned_tokens.masked_fill(
            self._padding_mask(text_lengths, unaligned_tokens.shape[1], unaligned_tokens.device), -100)
        audio = audio[:, :, :wav_lengths.max()]

        attention_mask = torch.ones_like(audio).squeeze(1)
        if self.provide_attention_mask:
            attention_mask = attention_mask.masked_fill(
                self._padding_mask(wav_lengths, attention_mask.shape[-1], attention_mask.device), 0)
        # Normalize with statistics over the whole batch (padding included).
        audio = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
        audio = audio.squeeze(1)  # Get rid of the channels; w2v re-adds them.

        model_inp = fea_extractor if self.remove_feature_extractor else audio
        outputs = self.w2v(input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens)

        if self.output_wer:
            self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
            self.last_labels.append(unaligned_tokens)
            # Keep only the most recent batches.
            del self.last_pred[:-self._WER_HISTORY]
            del self.last_labels[:-self._WER_HISTORY]
        return outputs.loss

    def decode_ctc(self, output):
        """
        Greedy CTC decode: collapse runs of repeated ids, then drop the blank id 0.

        :param output: a sequence of ids, or a 1-D tensor of ids.
        :return: list of the remaining ids.
        """
        if isinstance(output, torch.Tensor):
            output = output.tolist()
        return [token for token, _ in groupby(output) if token != 0]

    def get_debug_values(self, step, net_name):
        """
        Every 100 steps, report ``{'wer': ...}`` over the retained predictions and labels, and print one sample.

        Returns an empty dict on other steps, when WER tracking is disabled, or before any batch has been recorded.
        """
        res = {}
        if not self.output_wer or step % 100 != 0 or not self.last_labels:
            return res
        from datasets import load_metric
        wer_metric = load_metric("wer")
        label_strings = []
        pred_strings = []
        for labels, preds in zip(self.last_labels, self.last_pred):
            # Map the -100 "ignore" padding back to token 0 without mutating the stored tensor.
            labels = labels.masked_fill(labels == -100, 0)
            # Convert rows to plain ints before the symbol-table lookup.
            label_strings.extend(only_letters(sequence_to_text(lbl.tolist())) for lbl in labels)
            pred_strings.extend(only_letters(sequence_to_text(self.decode_ctc(pred))) for pred in preds)
        res['wer'] = wer_metric.compute(predictions=pred_strings, references=label_strings)
        print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
        return res

    def inference(self, audio):
        """
        Greedy-decode a batch of audio into one list of symbol ids per sample.

        This method feeds normalized raw audio to the model. It is therefore only meaningful when the model
        still owns its feature extractor, i.e. it was not built with ``remove_feature_extractor=True``.
        """
        with torch.no_grad():
            audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
            logits = self.w2v(input_values=audio_norm.squeeze(1)).logits
            pred = logits.argmax(dim=-1)
        return [self.decode_ctc(p) for p in pred]
2022-02-13 03:00:59 +00:00
@register_model
def register_wav2vec_feature_extractor(opt_net, opt):
    """
    Build a Wav2VecFeatureExtractor from the ``kwargs`` entry of ``opt_net``.

    ``{}`` is passed to ``opt_get`` as the fallback when that entry is absent. ``opt`` is unused.
    """
    extractor_kwargs = opt_get(opt_net, ['kwargs'], {})
    return Wav2VecFeatureExtractor(**extractor_kwargs)
2022-02-13 03:00:59 +00:00
@register_model
def register_wav2vec2_finetune(opt_net, opt):
    """
    Build a Wav2VecWrapper from the ``kwargs`` entry of ``opt_net``.

    ``{}`` is passed to ``opt_get`` as the fallback when that entry is absent. ``opt`` is unused.
    """
    wrapper_kwargs = opt_get(opt_net, ['kwargs'], {})
    return Wav2VecWrapper(**wrapper_kwargs)
if __name__ == '__main__':
    # Ad-hoc smoke test. The checkpoint and clip paths below point at the original author's machine.
    print(only_letters("Hello, world!"))
    basis = 'facebook/wav2vec2-large-960h'
    feature_net = Wav2VecFeatureExtractor(basis_model=basis)
    asr_net = Wav2VecWrapper(basis_model=basis, freeze_transformer=True, remove_feature_extractor=True)

    # Random batch; neither call mutates the length tensor, so it is shared between them.
    lengths = torch.tensor([20000, 30000])
    features = feature_net(torch.randn(2, 1, 50000), lengths)
    dummy_audio = torch.randn(2, 1, 50000)
    dummy_tokens = torch.randint(0, 40, (2, 70))
    loss = asr_net(dummy_audio, dummy_tokens, lengths, torch.tensor([35, 50]), features)

    asr_net.get_debug_values(0, "")
    state = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')
    asr_net.load_state_dict(state)
    clip = load_audio('Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', 16000)
    # NOTE(review): asr_net was built with remove_feature_extractor=True, so its feature extractor is an Identity.
    # inference() feeds it normalized raw audio rather than extracted features — confirm this path works.
    decoded = asr_net.inference(clip.unsqueeze(0))
    print(sequence_to_text(decoded[0]))