forked from mrq/DL-Art-School
Attempt to make w2v play with DDP AND checkpointing
This commit is contained in:
parent
f3776f1992
commit
baf7b65566
|
@ -16,11 +16,34 @@ def only_letters(string):
|
||||||
return ''.join(filter(allowlist.__contains__, string.upper()))
|
return ''.join(filter(allowlist.__contains__, string.upper()))
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2VecFeatureExtractor(nn.Module):
|
||||||
|
"""
|
||||||
|
Basic wrapper that only does feature extraction. Useful to build out this portion of the model so it can be
|
||||||
|
operated through DDP.
|
||||||
|
"""
|
||||||
|
def __init__(self, basis_model='facebook/wav2vec2-large'):
|
||||||
|
super().__init__()
|
||||||
|
w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
|
||||||
|
self.extractor = w2v.wav2vec2.feature_extractor
|
||||||
|
|
||||||
|
for p in self.extractor.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
p.DO_NOT_TRAIN = True
|
||||||
|
|
||||||
|
def forward(self, audio, wav_lengths):
|
||||||
|
with torch.no_grad():
|
||||||
|
audio = audio[:, :, :wav_lengths.max()]
|
||||||
|
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
||||||
|
return self.extractor(audio_norm.squeeze(1))
|
||||||
|
|
||||||
|
|
||||||
class Wav2VecWrapper(nn.Module):
|
class Wav2VecWrapper(nn.Module):
|
||||||
"""
|
"""
|
||||||
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
||||||
"""
|
"""
|
||||||
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True):
|
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True,
|
||||||
|
checkpointing_enabled=True, provide_attention_mask=False, spec_augment=True,
|
||||||
|
remove_feature_extractor=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.provide_attention_mask = provide_attention_mask
|
self.provide_attention_mask = provide_attention_mask
|
||||||
|
|
||||||
|
@ -32,11 +55,17 @@ class Wav2VecWrapper(nn.Module):
|
||||||
self.w2v.config.pad_token_id = 0
|
self.w2v.config.pad_token_id = 0
|
||||||
self.w2v.config.ctc_loss_reduction = 'sum'
|
self.w2v.config.ctc_loss_reduction = 'sum'
|
||||||
self.w2v.config.apply_spec_augment = spec_augment
|
self.w2v.config.apply_spec_augment = spec_augment
|
||||||
|
self.remove_feature_extractor = remove_feature_extractor
|
||||||
|
|
||||||
|
if remove_feature_extractor:
|
||||||
|
# The values passed in to the w2v model in this case are the outputs of the feature extractor.
|
||||||
|
self.w2v.wav2vec2.feature_extractor = nn.Identity()
|
||||||
|
else:
|
||||||
|
# We always freeze the feature extractor, which needs some special operations in DLAS
|
||||||
|
for p in self.w2v.wav2vec2.feature_extractor.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
p.DO_NOT_TRAIN = True
|
||||||
|
|
||||||
# We always freeze the feature extractor, which needs some special operations in DLAS
|
|
||||||
for p in self.w2v.wav2vec2.feature_extractor.parameters():
|
|
||||||
p.requires_grad = False
|
|
||||||
p.DO_NOT_TRAIN = True
|
|
||||||
if freeze_transformer:
|
if freeze_transformer:
|
||||||
# Also freeze the encoder here.
|
# Also freeze the encoder here.
|
||||||
for p in list(self.w2v.wav2vec2.encoder.parameters()) + list(self.w2v.wav2vec2.feature_projection.parameters()):
|
for p in list(self.w2v.wav2vec2.encoder.parameters()) + list(self.w2v.wav2vec2.feature_projection.parameters()):
|
||||||
|
@ -48,19 +77,19 @@ class Wav2VecWrapper(nn.Module):
|
||||||
self.last_pred = []
|
self.last_pred = []
|
||||||
self.last_labels = []
|
self.last_labels = []
|
||||||
|
|
||||||
def forward(self, audio, unaligned_tokens, wav_lengths, text_lengths):
|
def forward(self, audio, unaligned_tokens, wav_lengths, text_lengths, fea_extractor=None):
|
||||||
audio = audio[:, :, :wav_lengths.max()]
|
|
||||||
unaligned_tokens = unaligned_tokens[:, :text_lengths.max()]
|
unaligned_tokens = unaligned_tokens[:, :text_lengths.max()]
|
||||||
|
audio = audio[:, :, :wav_lengths.max()]
|
||||||
attention_mask = torch.ones_like(audio).squeeze(1)
|
attention_mask = torch.ones_like(audio).squeeze(1)
|
||||||
|
audio = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
||||||
|
audio = audio.squeeze(1) # Get rid of the channels; w2v re-adds them.
|
||||||
for b in range(audio.shape[0]):
|
for b in range(audio.shape[0]):
|
||||||
attention_mask[b, wav_lengths[b]:] = 0
|
if self.provide_attention_mask:
|
||||||
|
attention_mask[b, wav_lengths[b]:] = 0
|
||||||
unaligned_tokens[b, text_lengths[b]:] = -100
|
unaligned_tokens[b, text_lengths[b]:] = -100
|
||||||
|
|
||||||
audio_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
|
model_inp = fea_extractor if self.remove_feature_extractor else audio
|
||||||
if self.provide_attention_mask:
|
outputs = self.w2v(input_values=model_inp, attention_mask=attention_mask, labels=unaligned_tokens)
|
||||||
outputs = self.w2v(input_values=audio_norm.squeeze(1), attention_mask=attention_mask, labels=unaligned_tokens)
|
|
||||||
else:
|
|
||||||
outputs = self.w2v(input_values=audio_norm.squeeze(1), labels=unaligned_tokens)
|
|
||||||
|
|
||||||
if self.output_wer:
|
if self.output_wer:
|
||||||
self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
|
self.last_pred.append(torch.argmax(outputs.logits, dim=-1))
|
||||||
|
@ -101,6 +130,11 @@ class Wav2VecWrapper(nn.Module):
|
||||||
return [self.decode_ctc(p) for p in pred]
|
return [self.decode_ctc(p) for p in pred]
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def register_wav2vec_feature_extractor(opt_net, opt):
|
||||||
|
return Wav2VecFeatureExtractor(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_wav2vec2_finetune(opt_net, opt):
|
def register_wav2vec2_finetune(opt_net, opt):
|
||||||
return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))
|
return Wav2VecWrapper(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
@ -108,8 +142,10 @@ def register_wav2vec2_finetune(opt_net, opt):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print(only_letters("Hello, world!"))
|
print(only_letters("Hello, world!"))
|
||||||
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True)
|
fe = Wav2VecFeatureExtractor(basis_model='facebook/wav2vec2-large-960h')
|
||||||
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]))
|
w2v = Wav2VecWrapper(basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, remove_feature_extractor=True)
|
||||||
|
fea = fe(torch.randn(2,1,50000), torch.tensor([20000, 30000]))
|
||||||
|
loss = w2v(torch.randn(2,1,50000), torch.randint(0,40,(2,70)), torch.tensor([20000, 30000]), torch.tensor([35, 50]), fea)
|
||||||
w2v.get_debug_values(0,"")
|
w2v.get_debug_values(0,"")
|
||||||
|
|
||||||
sd = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')
|
sd = torch.load('../experiments/train_wav2vec_mass_archived_r0/models/19500_wav2vec.pth')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user