Attempt to make w2v play with DDP AND checkpointing

This commit is contained in:
James Betker 2022-02-18 18:47:11 -07:00
parent f3776f1992
commit baf7b65566

View File

@ -16,11 +16,34 @@ def only_letters(string):
return ''.join(filter(allowlist.__contains__, string.upper()))
class Wav2VecFeatureExtractor(nn.Module):
"""
Basic wrapper that only does feature extraction. Useful to build out this portion of the model so it can be
operated through DDP.
"""
def __init__(self, basis_model='facebook/wav2vec2-large'):
super().__init__()
w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
self.extractor = w2v.wav2vec2.feature_extractor
for p in self.extractor.parameters():
p.requires_grad = False
p.DO_NOT_TRAIN = True
def forward(self, audio, wav_lengths):
with torch.no_grad():
audio = audio[:, :, :wav_lengths.max()]
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
return self.extractor(audio_norm.squeeze(1))
class Wav2VecWrapper(nn.Module):
"""
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
"""
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True):
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
remove_feature_extractor=False):
super().__init__()
self.provide_attention_mask = provide_attention_mask
@ -32,11 +55,17 @@ class Wav2VecWrapper(nn.Module):
self.w2v.config.pad_token_id = 0
self.w2v.config.ctc_loss_reduction = 'sum'
self.w2v.config.apply_spec_augment = spec_augment
self.remove_feature_extractor = remove_feature_extractor
if remove_feature_extractor:
# The values passed in to the w2v model in this case are the outputs of the feature extractor.
self.w2v.wav2vec2.feature_extractor = nn.Identity()
else:
# We always freeze the feature extractor, which needs some special operations in DLAS
for p in self.w2v.wav2vec2.feature_extractor.parameters():
p.requires_grad = False
p.DO_NOT_TRAIN = True
# We always freeze the feature extractor, which needs some special operations in DLAS
for p in self.w2v.wav2vec2.feature_extractor.parameters():
p.requires_grad = False
p.DO_NOT_TRAIN = True
if freeze_transformer:
# Also freeze the encoder here.
for p in list(self.w2v.wav2vec2.encoder.parameters()) + list(self.w2v.wav2vec2.feature_projection.parameters()):
@ -48,19 +77,19 @@ class Wav2VecWrapper(nn.Module):
self.last_pred = []
self.last_labels = []
def forward(self, audio, unaligned_tokens, wav_lengths, text_lengths):
audio = audio[:, :, :wav_lengths.max()]
def forward(self, audio, unaligned_tokens, wav_lengths, text_lengths, fea_extractor=None):
unaligned_tokens = unaligned_tokens[:, :text_lengths.max()]
audio = audio[:, :, :wav_lengths.max()]
attention_mask = torch.ones_like(audio).squeeze(1)
audio = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
audio = audio.squeeze(1) # Get rid of the channels; w2v re-adds them.
for b in range(audio.shape[0]):
attention_mask[b, wav_lengths[b]:] = 0
if self.provide_attention_mask:
attention_mask[b, wav_lengths[b]:] = 0
unaligned_tokens[b, text_lengths[b]:] = -100
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
if self.provide_attention_mask:
outputs = self.w2v(input_values=audio_norm.squeeze(1), attention_mask=attention_mask, labels=unaligned_tokens)
else:
outputs = self.w2v(input_values=audio_norm.squeeze(1), labels=unaligned_tokens)
model_inp = fea_extractor if self.remove_feature_extractor else audio
outputs = self.w2v(input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens)
if self.output_wer:
self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
@ -101,6 +130,11 @@ class Wav2VecWrapper(nn.Module):
return [self.decode_ctc(p) for p in pred]
@register_model
def register_wav2vec_feature_extractor(opt_net, opt):
return Wav2VecFeatureExtractor(**opt_get(opt_net, ['kwargs'], {}))
@register_model
def register_wav2vec2_finetune(opt_net, opt):
return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))
@ -108,8 +142,10 @@ def register_wav2vec2_finetune(opt_net, opt):
if __name__ == '__main__':
print(only_letters("Hello, world!"))
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True)
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]))
fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True)
fea = fe(torch.randn(2,1,50000), torch.tensor([20000, 30000]))
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
w2v.get_debug_values(0,"")
sd = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')