forked from mrq/DL-Art-School
DataParallel Fix
This commit is contained in:
parent
34001ad765
commit
bcba65c539
|
@ -145,7 +145,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
if opt_get(opt, ['ddp_static_graph'], False):
|
if opt_get(opt, ['ddp_static_graph'], False):
|
||||||
dnet._set_static_graph()
|
dnet._set_static_graph()
|
||||||
else:
|
else:
|
||||||
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
|
dnet = DataParallel(anet, device_ids=[torch.cuda.current_device()])
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
dnet.train()
|
dnet.train()
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -100,10 +100,9 @@ if __name__ == '__main__':
|
||||||
'text_seq_key': 'padded_text',
|
'text_seq_key': 'padded_text',
|
||||||
'text_seq_lengths_key': 'text_lengths',
|
'text_seq_lengths_key': 'text_lengths',
|
||||||
}
|
}
|
||||||
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, checkpointing_enabled=False)
|
model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False)
|
||||||
model.w2v = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-large-960h')
|
weights = torch.load('X:\\dlas\\experiments/train_wav2vec_mass_diverse_initial_annealing_large_pt/models/7000_wav2vec.pth')
|
||||||
weights = torch.load('X:\\dlas\\experiments\\train_wav2vec_mass_large\\models\\0_wav2vec.pth')
|
model.load_state_dict(weights)
|
||||||
#model.load_state_dict(weights)
|
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
eval = WerEvaluator(model, opt_eval, env, detokenizer_fn=fb_detokenize)
|
eval = WerEvaluator(model, opt_eval, env)
|
||||||
print(eval.perform_eval())
|
print(eval.perform_eval())
|
Loading…
Reference in New Issue
Block a user