diff --git a/codes/models/lucidrains/performer/performer_enc_dec.py b/codes/models/lucidrains/performer/performer_enc_dec.py index 192f63a0..2ef411f4 100644 --- a/codes/models/lucidrains/performer/performer_enc_dec.py +++ b/codes/models/lucidrains/performer/performer_enc_dec.py @@ -1,9 +1,19 @@ + import re import torch from torch import nn + +# for god knows why it cannot "see" performer_pytorch +import os +import sys +prev_sys = copy(sys.path) +sys.path.insert(0, os.path.dirname(os.path.realpath(__file__))) + from performer_pytorch import PerformerLM from autoregressive_wrapper import AutoregressiveWrapper +sys.path = prev_sys + ENC_PREFIX = 'enc_' DEC_PREFIX = 'dec_'