DataParallel Fix

James Betker 2022-02-19 20:36:35 -07:00
parent 34001ad765
commit bcba65c539
2 changed files with 5 additions and 6 deletions


@@ -145,7 +145,7 @@ class ExtensibleTrainer(BaseModel):
                 if opt_get(opt, ['ddp_static_graph'], False):
                     dnet._set_static_graph()
             else:
-                dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
+                dnet = DataParallel(anet, device_ids=[torch.cuda.current_device()])
             if self.is_train:
                 dnet.train()
             else:
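For context, a minimal standalone sketch of the pattern the new line uses: pinning DataParallel's device list to the process's current CUDA device rather than an externally configured gpu_ids list. The TinyNet module below is a stand-in for whatever network the trainer built; it is not part of the repository, and the sketch assumes at least one CUDA device is available.

    import torch
    import torch.nn as nn
    from torch.nn import DataParallel

    class TinyNet(nn.Module):
        # Stand-in module; the real trainer wraps the network it constructed.
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 4)

        def forward(self, x):
            return self.proj(x)

    net = TinyNet().cuda()
    # With a single-entry device_ids list, DataParallel keeps the module on the GPU
    # this process is actually using, so a stale or mismatched opt['gpu_ids'] value
    # can no longer point replicas at devices the process does not own.
    net = DataParallel(net, device_ids=[torch.cuda.current_device()])
    out = net(torch.randn(8, 16).cuda())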


@@ -100,10 +100,9 @@ if __name__ == '__main__':
         'text_seq_key': 'padded_text',
         'text_seq_lengths_key': 'text_lengths',
     }
-    model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, checkpointing_enabled=False)
-    model.w2v = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-large-960h')
-    weights = torch.load('X:\\dlas\\experiments\\train_wav2vec_mass_large\\models\\0_wav2vec.pth')
-    #model.load_state_dict(weights)
+    model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False)
+    weights = torch.load('X:\\dlas\\experiments/train_wav2vec_mass_diverse_initial_annealing_large_pt/models/7000_wav2vec.pth')
+    model.load_state_dict(weights)
     model = model.cuda()
-    eval = WerEvaluator(model, opt_eval, env, detokenizer_fn=fb_detokenize)
+    eval = WerEvaluator(model, opt_eval, env)
     print(eval.perform_eval())
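Since the script now hands the torch.load result straight to load_state_dict, one pitfall worth noting: checkpoints saved from a DataParallel- or DDP-wrapped model carry a 'module.' prefix on every key. Below is a small sketch of a loader that tolerates this; the helper name, the CPU map_location, and the strict=False choice are illustrative assumptions, not code from the repository.

    import torch

    def load_weights(model, checkpoint_path):
        # Load onto the CPU first so the target GPU does not matter at load time.
        state = torch.load(checkpoint_path, map_location='cpu')
        # Strip a leading 'module.' prefix left behind by DataParallel/DDP wrapping, if present.
        state = {k[len('module.'):] if k.startswith('module.') else k: v
                 for k, v in state.items()}
        # strict=False reports key mismatches instead of raising on them.
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing or unexpected:
            print('missing keys:', missing)
            print('unexpected keys:', unexpected)
        return model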