DataParallel Fix
This commit is contained in:
parent
34001ad765
commit
bcba65c539
|
@ -145,7 +145,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
if opt_get(opt, ['ddp_static_graph'], False):
|
||||
dnet._set_static_graph()
|
||||
else:
|
||||
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
|
||||
dnet = DataParallel(anet, device_ids=[torch.cuda.current_device()])
|
||||
if self.is_train:
|
||||
dnet.train()
|
||||
else:
|
||||
|
|
|
@ -100,10 +100,9 @@ if __name__ == '__main__':
|
|||
'text_seq_key': 'padded_text',
|
||||
'text_seq_lengths_key': 'text_lengths',
|
||||
}
|
||||
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, checkpointing_enabled=False)
|
||||
model.w2v = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-large-960h')
|
||||
weights = torch.load('X:\\dlas\\experiments\\train_wav2vec_mass_large\\models\\0_wav2vec.pth')
|
||||
#model.load_state_dict(weights)
|
||||
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False)
|
||||
weights = torch.load('X:\\dlas\\experiments/train_wav2vec_mass_diverse_initial_annealing_large_pt/models/7000_wav2vec.pth')
|
||||
model.load_state_dict(weights)
|
||||
model = model.cuda()
|
||||
eval = WerEvaluator(model, opt_eval, env, detokenizer_fn=fb_detokenize)
|
||||
eval = WerEvaluator(model, opt_eval, env)
|
||||
print(eval.perform_eval())
|
Loading…
Reference in New Issue
Block a user