A few mods to make wav2vec2 trainable with DDP on DLAS

This commit is contained in:
James Betker 2022-02-15 06:28:54 -07:00
parent 52b61b9f77
commit 2bdb515068
4 changed files with 7 additions and 6 deletions

View File

@ -20,11 +20,11 @@ class Wav2VecWrapper(nn.Module):
"""
Basic wrapper class that makes Wav2Vec2 usable by DLAS.
"""
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True):
def __init__(self, vocab_size=148, basis_model='facebook/wav2vec2-large', freeze_transformer=False, output_wer=True, checkpointing_enabled=True):
super().__init__()
self.w2v = Wav2Vec2ForCTC.from_pretrained(basis_model)
# Perform some surgery to get the model we actually want.
self.w2v.wav2vec2.encoder.gradient_checkpointing = True
self.w2v.wav2vec2.encoder.gradient_checkpointing = checkpointing_enabled
self.w2v.lm_head = nn.Linear(self.w2v.config.hidden_size, vocab_size)
self.w2v.config.vocab_size = vocab_size
self.w2v.config.pad_token_id = 0

View File

@ -316,7 +316,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_mass.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wav2vec_mass_initial_annealing.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -130,9 +130,9 @@ class ExtensibleTrainer(BaseModel):
dnet = DistributedDataParallel(anet, delay_allreduce=True)
else:
from torch.nn.parallel.distributed import DistributedDataParallel
# Do NOT be tempted to put find_unused_parameters=True here. It will not work in the current incarnation of this trainer.
# Use all of your parameters in training, or delete them!
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
# Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
# used and in a few other cases. But you can try it if you really want.
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
# DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
# which does not work with this trainer (as stated above). However, if the graph is not subject
# to control flow alterations, you can set this option to allow gradient checkpointing. Beware that

View File

@ -8,6 +8,7 @@ from data import create_dataset, create_dataloader
from models.asr.w2v_wrapper import only_letters
from models.tacotron2.text import sequence_to_text
# Fine-tuned target for w2v-large: 4.487% WER.
class WerEvaluator(evaluator.Evaluator):
"""