diff --git a/codes/models/asr/__init__.py b/codes/models/asr/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/asr/w2v_wrapper.py b/codes/models/asr/w2v_wrapper.py new file mode 100644 index 00000000..153a8385 --- /dev/null +++ b/codes/models/asr/w2v_wrapper.py @@ -0,0 +1,108 @@ +from itertools import groupby + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer + +from data.audio.unsupervised_audio_dataset import load_audio +from models.tacotron2.text import symbols, sequence_to_text +from trainer.networks import register_model +from utils.util import opt_get + + +class Wav2VecWrapper(nn.Module): + """ + Basic wrapper class that makes Wav2Vec2 usable by DLAS. + """ + def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True): + super().__init__() + self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model) + # Perform some surgery to get the model we actually want. + self.w2v.wav2vec2.encoder.gradient_checkpointing = True + self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size) + self.w2v.config.vocab_size = vocab_size + self.w2v.config.pad_token_id = 0 + self.w2v.config.ctc_loss_reduction = 'mean' + self.w2v.config.apply_spec_augment = True + + # We always freeze the feature extractor, which needs some special operations in DLAS + for p in self.w2v.wav2vec2.feature_extractor.parameters(): + p.requires_grad = False + p.DO_NOT_TRAIN = True + if freeze_transformer: + # Also freeze the encoder here. + for p in list(self.w2v.wav2vec2.encoder.parameters()) + list(self.w2v.wav2vec2.feature_projection.parameters()): + p.requires_grad = False + p.DO_NOT_TRAIN = True + + self.output_wer = output_wer + if output_wer: + self.last_pred = [] + self.last_labels = [] + + def forward(self, audio, unaligned_tokens, wav_lengths, text_lengths): + audio = audio[:, :, :wav_lengths.max()] + unaligned_tokens = unaligned_tokens[:, :text_lengths.max()] + attention_mask = torch.ones_like(audio).squeeze(1) + for b in range(audio.shape[0]): + attention_mask[b, wav_lengths[b]:] = 0 + unaligned_tokens[b, text_lengths[b]:] = -100 + + audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7) + outputs = self.w2v(input_values=audio_norm.squeeze(1), attention_mask=attention_mask, labels=unaligned_tokens) + + if self.output_wer: + self.last_pred.append(torch.argmax(outputs.logits, dim=-1)) + if len(self.last_pred) > 10: + self.last_pred = self.last_pred[1:] + self.last_labels.append(unaligned_tokens) + if len(self.last_labels) > 10: + self.last_labels = self.last_labels[1:] + return outputs.loss + + def decode_ctc(self, output): + if isinstance(output, torch.Tensor): + output = output.tolist() + tokens = [token_group[0] for token_group in groupby(output)] + filtered_tokens = list(filter(lambda token: token != 0, tokens)) + return filtered_tokens + + def get_debug_values(self, step, net_name): + res = {} + if self.output_wer and step % 100 == 0: + from datasets import load_metric + wer_metric = load_metric("wer") + label_strings = [] + pred_strings = [] + for last_labels, last_pred in zip(self.last_labels, self.last_pred): + last_labels[last_labels == -100] = 0 + label_strings.extend([sequence_to_text(lbl) for lbl in last_labels]) + pred_strings.extend([sequence_to_text(self.decode_ctc(pred)) for pred in last_pred]) + wer = wer_metric.compute(predictions=pred_strings, references=label_strings) + res['wer'] = wer + print(f"Sample prediction: {pred_strings[0]} <=> {label_strings[0]}") + return res + + def inference(self, audio): + audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7) + logits = self.w2v(input_values=audio_norm.squeeze(1)).logits + pred = logits.argmax(dim=-1) + return self.decode_ctc(pred) + + +@register_model +def register_wav2vec2_finetune(opt_net, opt): + return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {})) + + +if __name__ == '__main__': + w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True) + loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50])) + w2v.get_debug_values(0,"") + + sd = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth') + w2v.load_state_dict(sd) + pred = w2v.inference(load_audio('Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav', 16000).unsqueeze(0)) + res = sequence_to_text(pred[0]) + print(res) diff --git a/codes/train.py b/codes/train.py index 334ffab2..5d443872 100644 --- a/codes/train.py +++ b/codes/train.py @@ -316,7 +316,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_mass.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()