diff --git a/codes/models/asr/w2v_wrapper.py b/codes/models/asr/w2v_wrapper.py
index 0348959e..237d93d8 100644
--- a/codes/models/asr/w2v_wrapper.py
+++ b/codes/models/asr/w2v_wrapper.py
@@ -20,11 +20,11 @@ class Wav2VecWrapper(nn.Module):
     """
     Basic wrapper class that makes Wav2Vec2 usable by DLAS.
     """
-    def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True):
+    def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True):
         super().__init__()
         self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
         # Perform some surgery to get the model we actually want.
-        self.w2v.wav2vec2.encoder.gradient_checkpointing = True
+        self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
         self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
         self.w2v.config.vocab_size = vocab_size
         self.w2v.config.pad_token_id = 0
diff --git a/codes/train.py b/codes/train.py
index 5d443872..fb71e274 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -316,7 +316,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_mass.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_mass_initial_annealing.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index ccaac148..6ba9cf63 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -130,9 +130,9 @@ class ExtensibleTrainer(BaseModel):
                     dnet = DistributedDataParallel(anet, delay_allreduce=True)
                 else:
                     from torch.nn.parallel.distributed import DistributedDataParallel
-                    # Do NOT be tempted to put find_unused_parameters=True here. It will not work in the current incarnation of this trainer.
-                    # Use all of your parameters in training, or delete them!
-                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
+                    # Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
+                    # used and in a few other cases. But you can try it if you really want.
+                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
                 # DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
                 # which does not work with this trainer (as stated above). However, if the graph is not subject
                 # to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
diff --git a/codes/trainer/eval/eval_wer.py b/codes/trainer/eval/eval_wer.py
index 604965f0..b1b6ee2b 100644
--- a/codes/trainer/eval/eval_wer.py
+++ b/codes/trainer/eval/eval_wer.py
@@ -8,6 +8,7 @@
 from data import create_dataset, create_dataloader
 from models.asr.w2v_wrapper import only_letters
 from models.tacotron2.text import sequence_to_text
 
+# Fine-tuned target for w2v-large: 4.487% WER.
 class WerEvaluator(evaluator.Evaluator):
     """
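
Usage note (illustrative, not part of the patch): a minimal sketch of constructing the wrapper with the new flag, assuming the `codes/` directory is on the Python path. The constructor signature is taken from the w2v_wrapper.py hunk above; whether to pair this with the new `ddp_find_unused_parameters` option (read via `opt_get` in ExtensibleTrainer.py and defaulting to False when omitted) depends on your training graph.

# Sketch only: build the DLAS wrapper with gradient checkpointing disabled.
# checkpointing_enabled=True reproduces the previous hard-coded behaviour.
from models.asr.w2v_wrapper import Wav2VecWrapper

model = Wav2VecWrapper(
    vocab_size=148,                         # defaults taken from the diff
    basis_model='facebook/wav2vec2-large',
    checkpointing_enabled=False,            # new flag introduced by this patch
)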