DataParallel Fix

This commit is contained in:
James Betker 2022-02-19 20:36:35 -07:00
parent 34001ad765
commit bcba65c539
2 changed files with 5 additions and 6 deletions

View File

@ -145,7 +145,7 @@ class ExtensibleTrainer(BaseModel):
if opt_get(opt, ['ddp_static_graph'], False):
dnet._set_static_graph()
else:
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
dnet = DataParallel(anet, device_ids=[torch.cuda.current_device()])
if self.is_train:
dnet.train()
else:

View File

@ -100,10 +100,9 @@ if __name__ == '__main__':
'text_seq_key': 'padded_text',
'text_seq_lengths_key': 'text_lengths',
}
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, checkpointing_enabled=False)
model.w2v = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-large-960h')
weights = torch.load('X:\\dlas\\experiments\\train_wav2vec_mass_large\\models\\0_wav2vec.pth')
#model.load_state_dict(weights)
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False)
weights = torch.load('X:\\dlas\\experiments/train_wav2vec_mass_diverse_initial_annealing_large_pt/models/7000_wav2vec.pth')
model.load_state_dict(weights)
model = model.cuda()
eval = WerEvaluator(model, opt_eval, env, detokenizer_fn=fb_detokenize)
eval = WerEvaluator(model, opt_eval, env)
print(eval.perform_eval())