from itertools import groupby

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Attention, Wav2Vec2Model

from data.audio.unsupervised_audio_dataset import load_audio
from models.tacotron2.text import symbols, sequence_to_text
from trainer.networks import register_model
from utils.util import opt_get

# Characters kept by only_letters(); everything else is dropped before WER scoring.
_LETTER_ALLOWLIST = frozenset(" ABCDEFGHIJKLMNOPQRSTUVWXYZ'")

# Number of most recent (prediction, label) batches retained for WER reporting.
_WER_HISTORY_LENGTH = 10


def only_letters(string):
    """Upper-case ``string`` and keep only A-Z, space and apostrophe."""
    return ''.join(filter(_LETTER_ALLOWLIST.__contains__, string.upper()))


def _normalize_audio(audio):
    """Return ``audio`` shifted to zero mean and scaled to unit variance.

    The statistics are computed over the whole tensor (every batch element and
    sample together); 1e-7 is added to the variance to avoid division by zero.
    """
    return (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)


def _length_mask(lengths, max_len, device):
    """Return a bool mask of shape (len(lengths), max_len), True where position < lengths[i]."""
    positions = torch.arange(max_len, device=device)
    return positions.unsqueeze(0) < lengths.to(device).unsqueeze(1)


class Wav2VecFeatureExtractor(nn.Module):
    """
    Basic wrapper that only does feature extraction. Useful to build out this portion of the model so it can be
    operated through DDP.
    """

    def __init__(self, basis_model='facebook/wav2vec2-large'):
        super().__init__()
        w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
        self.extractor = w2v.wav2vec2.feature_extractor
        # The extractor is always frozen. DO_NOT_TRAIN is an extra marker, presumably read by the DLAS trainer.
        for p in self.extractor.parameters():
            p.requires_grad = False
            p.DO_NOT_TRAIN = True

    def forward(self, audio, wav_lengths):
        """Run the frozen convolutional feature extractor without tracking gradients.

        audio: (batch, 1, samples) waveform; truncated to ``wav_lengths.max()`` samples.
        wav_lengths: per-example sample counts (only the maximum is used here).
        Returns the extractor output for the normalized, channel-squeezed audio.
        """
        with torch.no_grad():
            audio = audio[:, :, :wav_lengths.max()]
            return self.extractor(_normalize_audio(audio).squeeze(1))


class Wav2VecWrapper(nn.Module):
    """
    Basic wrapper class that makes Wav2Vec2 usable by DLAS.

    Loads a pretrained Wav2Vec2ForCTC, replaces its CTC head with a fresh ``vocab_size`` projection and
    optionally freezes the transformer, drops the built-in feature extractor (so precomputed features are fed
    in instead), and ramps dropout up over training for distillation.
    """

    def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False,
                 output_wer=True, checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
                 remove_feature_extractor=False, ramp_dropout_mode=False, ramp_dropout_end=20000,
                 ramp_dropout_min=.1, ramp_dropout_max=.5):
        super().__init__()
        self.provide_attention_mask = provide_attention_mask

        self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
        # Perform some surgery to get the model we actually want.
        self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
        self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
        self.w2v.config.vocab_size = vocab_size
        self.w2v.config.pad_token_id = 0
        self.w2v.config.ctc_loss_reduction = 'sum'
        self.w2v.config.apply_spec_augment = spec_augment
        self.remove_feature_extractor = remove_feature_extractor

        # This is a provision for distilling by ramping up dropout.
        self.ramp_dropout_mode = ramp_dropout_mode
        self.ramp_dropout_end = ramp_dropout_end
        self.ramp_dropout_min = ramp_dropout_min
        self.ramp_dropout_max = ramp_dropout_max
        self.current_dropout_rate = ramp_dropout_min

        if remove_feature_extractor:
            # The values passed in to the w2v model in this case are the outputs of the feature extractor.
            self.w2v.wav2vec2.feature_extractor = nn.Identity()
        else:
            # We always freeze the feature extractor, which needs some special operations in DLAS
            for p in self.w2v.wav2vec2.feature_extractor.parameters():
                p.requires_grad = False
                p.DO_NOT_TRAIN = True

        if freeze_transformer:
            # Also freeze the encoder (and the feature projection that feeds it) here.
            frozen = list(self.w2v.wav2vec2.encoder.parameters()) + \
                list(self.w2v.wav2vec2.feature_projection.parameters())
            for p in frozen:
                p.requires_grad = False
                p.DO_NOT_TRAIN = True

        self.output_wer = output_wer
        if output_wer:
            # Rolling history of recent argmax predictions / padded labels, consumed by get_debug_values().
            self.last_pred = []
            self.last_labels = []

    def forward(self, audio, unaligned_tokens, wav_lengths, text_lengths, fea_extractor=None):
        """Compute the CTC loss for a batch.

        audio: (batch, 1, samples) waveform; truncated to ``wav_lengths.max()`` samples.
        unaligned_tokens: (batch, tokens) label ids, truncated to ``text_lengths.max()``. Positions at or
            beyond ``text_lengths[b]`` are set to -100, which HF's Wav2Vec2ForCTC ignores in the loss.
        wav_lengths / text_lengths: per-example valid lengths of ``audio`` / ``unaligned_tokens``.
        fea_extractor: precomputed feature-extractor output. It is required when the model was built with
            remove_feature_extractor=True, and is then fed to the model instead of ``audio``.
        Returns the loss of the underlying Wav2Vec2ForCTC.
        Raises ValueError if remove_feature_extractor=True and ``fea_extractor`` is None.
        """
        if self.remove_feature_extractor and fea_extractor is None:
            raise ValueError('fea_extractor must be provided when remove_feature_extractor=True.')

        unaligned_tokens = unaligned_tokens[:, :text_lengths.max()]
        # masked_fill returns a new tensor. Writing -100 through the slice above would be an in-place write
        # into the caller's label tensor, because slicing yields a view.
        valid_tokens = _length_mask(text_lengths, unaligned_tokens.shape[1], unaligned_tokens.device)
        unaligned_tokens = unaligned_tokens.masked_fill(~valid_tokens, -100)

        audio = audio[:, :, :wav_lengths.max()]
        if self.provide_attention_mask:
            # 1 for real samples, 0 for padding, in the same dtype as the audio.
            attention_mask = _length_mask(wav_lengths, audio.shape[-1], audio.device).to(audio.dtype)
        else:
            attention_mask = torch.ones_like(audio).squeeze(1)
        audio = _normalize_audio(audio).squeeze(1)  # Get rid of the channels; w2v re-adds them.

        model_inp = fea_extractor if self.remove_feature_extractor else audio
        outputs = self.w2v(input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens)

        if self.output_wer:
            self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
            self.last_pred = self.last_pred[-_WER_HISTORY_LENGTH:]
            self.last_labels.append(unaligned_tokens)
            self.last_labels = self.last_labels[-_WER_HISTORY_LENGTH:]
        return outputs.loss

    def decode_ctc(self, output):
        """Greedy CTC decode: collapse runs of repeated ids, then drop the blank id 0.

        Accepts a 1-D tensor or any iterable of ids; returns a list.
        """
        if isinstance(output, torch.Tensor):
            output = output.tolist()
        return [token for token, _ in groupby(output) if token != 0]

    def get_debug_values(self, step, net_name):
        """Return debug metrics for the trainer.

        Every 100 steps (when output_wer and at least one forward pass has been recorded) this computes
        'wer' over the stored history and prints one sample. In ramp_dropout_mode it also reports the
        current 'dropout_rate'.
        """
        res = {}
        if self.output_wer and step % 100 == 0 and self.last_pred:
            # NOTE(review): datasets.load_metric was removed in datasets 3.x; newer installs may need `evaluate`.
            from datasets import load_metric
            wer_metric = load_metric("wer")
            label_strings = []
            pred_strings = []
            for last_labels, last_pred in zip(self.last_labels, self.last_pred):
                # Map ignored (-100) positions to id 0 without mutating the stored history in place.
                last_labels = last_labels.masked_fill(last_labels == -100, 0)
                label_strings.extend([only_letters(sequence_to_text(lbl)) for lbl in last_labels])
                pred_strings.extend([only_letters(sequence_to_text(self.decode_ctc(pred))) for pred in last_pred])
            wer = wer_metric.compute(predictions=pred_strings, references=label_strings)
            res['wer'] = wer
            print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}")
        if self.ramp_dropout_mode:
            res['dropout_rate'] = self.current_dropout_rate
        return res

    def inference(self, audio):
        """Greedy-decode ``audio`` into one list of token ids per batch element."""
        pred = self.inference_logits(audio).argmax(dim=-1)
        return [self.decode_ctc(p) for p in pred]

    def inference_logits(self, audio):
        """Return the CTC logits for the normalized, channel-squeezed ``audio``."""
        return self.w2v(input_values=_normalize_audio(audio).squeeze(1)).logits

    def update_for_step(self, step, *args):
        """In ramp_dropout_mode, every 10 steps set all dropout rates on a linear min->max ramp.

        The ramp reaches ``ramp_dropout_max`` at ``ramp_dropout_end`` steps and stays there.
        """
        if not self.ramp_dropout_mode or step % 10 != 0:
            return
        # A non-positive ramp length would divide by zero; treat it as "already fully ramped".
        progress = min(step / self.ramp_dropout_end, 1) if self.ramp_dropout_end > 0 else 1
        dropout_gap = self.ramp_dropout_max - self.ramp_dropout_min
        new_dropout_rate = self.ramp_dropout_min + dropout_gap * progress
        self.current_dropout_rate = new_dropout_rate
        for module in self.w2v.modules():
            if isinstance(module, nn.Dropout):
                module.p = new_dropout_rate
            elif isinstance(module, Wav2Vec2Attention):
                # Attention layers are updated via their `dropout` attribute rather than an nn.Dropout child.
                module.dropout = new_dropout_rate


class Wav2VecBaseWrapper(nn.Module):
    """Wrap a pretrained Wav2Vec2Model (no CTC head) and return its last hidden state."""

    def __init__(self, basis_model='facebook/wav2vec2-large'):
        super().__init__()
        self.w2v = Wav2Vec2Model.from_pretrained(basis_model)

    def forward(self, audio):
        audio = _normalize_audio(audio).squeeze(1)  # Get rid of the channels; w2v re-adds them.
        outputs = self.w2v(input_values=audio)
        return outputs.last_hidden_state


@register_model
def register_wav2vec_feature_extractor(opt_net, opt):
    return Wav2VecFeatureExtractor(**opt_get(opt_net, ['kwargs'], {}))


@register_model
def register_wav2vec2_finetune(opt_net, opt):
    return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))


@register_model
def register_wav2vec2(opt_net, opt):
    return Wav2VecBaseWrapper(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
    w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True,
                         remove_feature_extractor=True, ramp_dropout_mode=True)
    w2v.update_for_step(8000)
    fea = fe(torch.randn(2, 1, 50000), torch.tensor([20000, 30000]))
    loss = w2v(torch.randn(2, 1, 50000), torch.randint(0, 40, (2, 70)), torch.tensor([20000, 30000]),
               torch.tensor([35, 50]), fea)
    w2v.get_debug_values(0, "")

    sd = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')
    w2v.load_state_dict(sd)
    pred = w2v.inference(
        load_audio('Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav',
                   16000).unsqueeze(0))
    res = sequence_to_text(pred[0])
    print(res)