forked from mrq/DL-Art-School
A few mods to make wav2vec2 trainable with DDP on DLAS
This commit is contained in:
parent
52b61b9f77
commit
2bdb515068
|
@ -20,11 +20,11 @@ class Wav2VecWrapper(nn.Module):
|
||||||
"""
|
"""
|
||||||
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
|
||||||
"""
|
"""
|
||||||
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True):
|
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
|
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
|
||||||
# Perform some surgery to get the model we actually want.
|
# Perform some surgery to get the model we actually want.
|
||||||
self.w2v.wav2vec2.encoder.gradient_checkpointing = True
|
self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
|
||||||
self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
|
self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
|
||||||
self.w2v.config.vocab_size = vocab_size
|
self.w2v.config.vocab_size = vocab_size
|
||||||
self.w2v.config.pad_token_id = 0
|
self.w2v.config.pad_token_id = 0
|
||||||
|
|
|
@ -316,7 +316,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_mass.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_mass_initial_annealing.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -130,9 +130,9 @@ class ExtensibleTrainer(BaseModel):
|
||||||
dnet = DistributedDataParallel(anet, delay_allreduce=True)
|
dnet = DistributedDataParallel(anet, delay_allreduce=True)
|
||||||
else:
|
else:
|
||||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||||
# Do NOT be tempted to put find_unused_parameters=True here. It will not work in the current incarnation of this trainer.
|
# Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
|
||||||
# Use all of your parameters in training, or delete them!
|
# used and in a few other cases. But you can try it if you really want.
|
||||||
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
|
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
|
||||||
# DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
|
# DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
|
||||||
# which does not work with this trainer (as stated above). However, if the graph is not subject
|
# which does not work with this trainer (as stated above). However, if the graph is not subject
|
||||||
# to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
|
# to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
|
||||||
|
|
|
@ -8,6 +8,7 @@ from data import create_dataset, create_dataloader
|
||||||
from models.asr.w2v_wrapper import only_letters
|
from models.asr.w2v_wrapper import only_letters
|
||||||
from models.tacotron2.text import sequence_to_text
|
from models.tacotron2.text import sequence_to_text
|
||||||
|
|
||||||
|
# Fine-tuned target for w2v-large: 4.487% WER.
|
||||||
|
|
||||||
class WerEvaluator(evaluator.Evaluator):
|
class WerEvaluator(evaluator.Evaluator):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user