|
|
@ -6,14 +6,11 @@ from torch import nn
|
|
|
|
# for god knows why it cannot "see" performer_pytorch
|
|
|
|
# for god knows why it cannot "see" performer_pytorch
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
import sys
|
|
|
|
import sys
|
|
|
|
prev_sys = copy(sys.path)
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)))
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)))
|
|
|
|
|
|
|
|
|
|
|
|
from performer_pytorch import PerformerLM
|
|
|
|
from performer_pytorch import PerformerLM
|
|
|
|
from autoregressive_wrapper import AutoregressiveWrapper
|
|
|
|
from autoregressive_wrapper import AutoregressiveWrapper
|
|
|
|
|
|
|
|
|
|
|
|
sys.path = prev_sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ENC_PREFIX = 'enc_'
|
|
|
|
ENC_PREFIX = 'enc_'
|
|
|
|
DEC_PREFIX = 'dec_'
|
|
|
|
DEC_PREFIX = 'dec_'
|
|
|
|
|
|
|
|
|
|
|
|