# DL-Art-School/dlas/models/lucidrains/performer/performer_enc_dec.py

# for god knows why it cannot "see" performer_pytorch
import os
import re
import sys

import torch
from torch import nn

# put this directory on sys.path before the performer imports run, so that
# `performer_pytorch` is resolvable as a top-level module
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)))

from dlas.models.lucidrains.performer.autoregressive_wrapper import \
    AutoregressiveWrapper
from dlas.models.lucidrains.performer.performer_pytorch import PerformerLM

# kwargs prefixed with 'enc_' / 'dec_' are routed to the encoder / decoder respectively
ENC_PREFIX = 'enc_'
DEC_PREFIX = 'dec_'


def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)


def string_begins_with(prefix, s):
    return bool(re.match(f'^{prefix}', s))


def group_by_key_prefix(prefix, d):
    # (keys matching the prefix, all other keys)
    return group_dict_by_key(lambda x: string_begins_with(prefix, x), d)


def group_by_key_prefix_and_remove_prefix(prefix, d):
    # like group_by_key_prefix, but the prefix is stripped from the matching keys
    kwargs_with_prefix, kwargs = group_dict_by_key(
        lambda x: string_begins_with(prefix, x), d)
    kwargs_without_prefix = dict(
        map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs


def extract_enc_dec_kwargs(kwargs):
    # pull out the `enc_*` and `dec_*` kwargs (prefixes stripped); the rest stays in kwargs
    enc_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(
        ENC_PREFIX, kwargs)
    dec_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(
        DEC_PREFIX, kwargs)
    return enc_kwargs, dec_kwargs, kwargs


def extract_and_set_enc_dec_kwargs(kwargs):
    enc_kwargs, dec_kwargs, kwargs = extract_enc_dec_kwargs(kwargs)
    # by default, reuse the encoder mask as the decoder's cross-attention context mask
    if 'mask' in enc_kwargs:
        dec_kwargs.setdefault('context_mask', enc_kwargs['mask'])
    return enc_kwargs, dec_kwargs, kwargs
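

# Illustrative example (not part of the original source) of how the prefix helpers
# split keyword arguments; the keys below are just placeholders:
#   extract_enc_dec_kwargs({'enc_depth': 6, 'dec_depth': 8, 'tie_token_embeds': True})
#   -> ({'depth': 6}, {'depth': 8}, {'tie_token_embeds': True})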


class PerformerEncDec(nn.Module):
    def __init__(
        self,
        dim,
        ignore_index=0,
        pad_value=0,
        tie_token_embeds=False,
        no_projection=False,
        **kwargs
    ):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)
        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, \
            'set the model dimension once via `dim`, not with `enc_dim` / `dec_dim`'

        # encoder and decoder share the dimension and projection setting;
        # only the decoder is causal and cross-attends to the encoder output
        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection
        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            enc.token_emb = dec.token_emb

        self.enc = enc
        self.dec = AutoregressiveWrapper(
            dec, ignore_index=ignore_index, pad_value=pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        # encode the source sequence, then autoregressively decode against it
        encodings = self.enc(seq_in, return_encodings=True, **enc_kwargs)
        return self.dec.generate(seq_out_start, seq_len, context=encodings, **{**dec_kwargs, **kwargs})

    def forward(self, seq_in, seq_out, enc_mask=None, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, mask=enc_mask,
                             return_encodings=True, **enc_kwargs)
        # the autoregressive wrapper computes the decoder loss against seq_out
        return self.dec(seq_out, context=encodings, context_mask=enc_mask, **dec_kwargs)
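

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the PerformerLM constructor accepts `num_tokens`, `max_seq_len`,
# `depth` and `heads` (routed here through the `enc_` / `dec_` prefixes); the
# sizes below are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = PerformerEncDec(
        dim=512,
        enc_num_tokens=256, enc_max_seq_len=1024, enc_depth=6, enc_heads=8,
        dec_num_tokens=256, dec_max_seq_len=1024, dec_depth=6, dec_heads=8,
    )

    src = torch.randint(0, 256, (1, 1024))
    tgt = torch.randint(0, 256, (1, 1024))
    src_mask = torch.ones_like(src).bool()

    # training step: forward returns the decoder loss from the autoregressive wrapper
    loss = model(src, tgt, enc_mask=src_mask)
    loss.backward()

    # inference: decode up to 1024 tokens from a single start token, passing
    # the encoder mask through the `enc_` prefix
    start = torch.zeros(1, 1, dtype=torch.long)
    generated = model.generate(src, start, seq_len=1024, enc_mask=src_mask)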