diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index a1c26053..4f0c8b16 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -145,7 +145,7 @@ class ExtensibleTrainer(BaseModel): if opt_get(opt, ['ddp_static_graph'], False): dnet._set_static_graph() else: - dnet = DataParallel(anet, device_ids=opt['gpu_ids']) + dnet = DataParallel(anet, device_ids=[torch.cuda.current_device()]) if self.is_train: dnet.train() else: diff --git a/codes/trainer/eval/eval_wer.py b/codes/trainer/eval/eval_wer.py index b38ff8ae..1d89b399 100644 --- a/codes/trainer/eval/eval_wer.py +++ b/codes/trainer/eval/eval_wer.py @@ -100,10 +100,9 @@ if __name__ == '__main__': 'text_seq_key': 'padded_text', 'text_seq_lengths_key': 'text_lengths', } - model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-960h', freeze_transformer=True, checkpointing_enabled=False) - model.w2v = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-large-960h') - weights = torch.load('X:\\dlas\\experiments\\train_wav2vec_mass_large\\models\\0_wav2vec.pth') - #model.load_state_dict(weights) + model = Wav2VecWrapper(vocab_size=148, basis_model='facebook/wav2vec2-large-robust-ft-libri-960h', freeze_transformer=True, checkpointing_enabled=False) + weights = torch.load('X:\\dlas\\experiments/train_wav2vec_mass_diverse_initial_annealing_large_pt/models/7000_wav2vec.pth') + model.load_state_dict(weights) model = model.cuda() - eval = WerEvaluator(model, opt_eval, env, detokenizer_fn=fb_detokenize) + eval = WerEvaluator(model, opt_eval, env) print(eval.perform_eval()) \ No newline at end of file