79 lines
2.9 KiB
Python
79 lines
2.9 KiB
Python
import re
|
|
import torch
|
|
from torch import nn
|
|
from performer_pytorch import PerformerLM
|
|
from autoregressive_wrapper import AutoregressiveWrapper
|
|
|
|
# Keyword-argument name prefixes used to route options: names starting with
# ENC_PREFIX are grouped for the encoder, names starting with DEC_PREFIX for
# the decoder.
ENC_PREFIX = 'enc_'
DEC_PREFIX = 'dec_'
|
|
|
|
def group_dict_by_key(cond, d):
    """Partition ``d`` into two dicts according to a predicate on its keys.

    Args:
        cond: Callable invoked once per key; its result is interpreted by
            truthiness.
        d: Dictionary to split. It is not modified.

    Returns:
        A 2-tuple ``(matched, unmatched)``. ``matched`` holds the entries
        whose key satisfied ``cond``; ``unmatched`` holds all the others.
    """
    matched, unmatched = {}, {}
    for name in d.keys():
        # Select the destination bucket first, then copy the entry into it.
        bucket = matched if cond(name) else unmatched
        bucket[name] = d[name]
    return matched, unmatched
|
|
|
|
def string_begins_with(prefix, str):
    """Return whether ``str`` starts with the literal text ``prefix``.

    The prefix is compared character for character. It is deliberately not
    treated as a regular expression, so characters such as ``.``, ``(`` or
    ``*`` in ``prefix`` only match themselves.

    Args:
        prefix: Literal prefix to look for.
        str: String to test. The parameter name shadows the builtin but is
            kept so that keyword callers continue to work.

    Returns:
        ``True`` if ``str`` begins with ``prefix``, ``False`` otherwise.
    """
    return str.startswith(prefix)
|
|
|
|
def group_by_key_prefix(prefix, d):
    """Split ``d`` by whether each key begins with ``prefix``.

    Args:
        prefix: Prefix handed to ``string_begins_with`` for every key.
        d: Dictionary to split. It is not modified.

    Returns:
        The 2-tuple produced by ``group_dict_by_key``: first the entries
        whose key begins with ``prefix``, then the remaining entries.
    """
    def has_prefix(key):
        return string_begins_with(prefix, key)

    return group_dict_by_key(has_prefix, d)
|
|
|
|
def group_by_key_prefix_and_remove_prefix(prefix, d):
    """Split ``d`` by key prefix and strip the prefix from the matching keys.

    Args:
        prefix: Prefix handed to ``string_begins_with`` for every key.
        d: Dictionary to split. It is not modified.

    Returns:
        A 2-tuple ``(stripped, remaining)``. ``stripped`` holds the entries
        whose key begins with ``prefix``, re-keyed with the first
        ``len(prefix)`` characters removed; ``remaining`` holds the other
        entries with their keys unchanged.
    """
    def has_prefix(key):
        return string_begins_with(prefix, key)

    prefixed, remaining = group_dict_by_key(has_prefix, d)
    cut = len(prefix)
    stripped = {key[cut:]: value for key, value in prefixed.items()}
    return stripped, remaining
|
|
|
|
def extract_enc_dec_kwargs(kwargs):
    """Split keyword arguments into encoder, decoder and leftover groups.

    Args:
        kwargs: Dictionary of keyword arguments. It is not modified.

    Returns:
        A 3-tuple ``(enc_kwargs, dec_kwargs, rest)``. ``enc_kwargs`` and
        ``dec_kwargs`` hold the entries whose names carried ``ENC_PREFIX``
        and ``DEC_PREFIX`` respectively, with that prefix removed; ``rest``
        holds every entry that carried neither prefix.
    """
    # Peel off the encoder options first, then search what is left for the
    # decoder options.
    enc_kwargs, rest = group_by_key_prefix_and_remove_prefix(ENC_PREFIX, kwargs)
    dec_kwargs, rest = group_by_key_prefix_and_remove_prefix(DEC_PREFIX, rest)
    return enc_kwargs, dec_kwargs, rest
|
|
|
|
def extract_and_set_enc_dec_kwargs(kwargs):
    """Split keyword arguments and default the decoder's context mask.

    Works like ``extract_enc_dec_kwargs``. In addition, when the encoder
    group contains ``mask``, that value becomes the decoder group's
    ``context_mask`` unless the decoder group already has one.

    Args:
        kwargs: Dictionary of keyword arguments. It is not modified.

    Returns:
        A 3-tuple ``(enc_kwargs, dec_kwargs, rest)``.
    """
    enc_kwargs, dec_kwargs, rest = extract_enc_dec_kwargs(kwargs)

    if 'mask' not in enc_kwargs:
        return enc_kwargs, dec_kwargs, rest

    # setdefault keeps an explicitly supplied decoder context_mask intact.
    dec_kwargs.setdefault('context_mask', enc_kwargs['mask'])
    return enc_kwargs, dec_kwargs, rest
|
|
|
|
class PerformerEncDec(nn.Module):
    """Encoder-decoder model assembled from two ``PerformerLM`` instances.

    The encoder is a ``PerformerLM`` built from the ``enc_``-prefixed
    constructor keyword arguments. The decoder is a second ``PerformerLM``
    built from the ``dec_``-prefixed ones, with ``causal`` and
    ``cross_attend`` forced to ``True``, wrapped in ``AutoregressiveWrapper``.
    """

    def __init__(
        self,
        dim,
        ignore_index = 0,
        pad_value = 0,
        tie_token_embeds = False,
        no_projection = False,
        **kwargs
    ):
        """Build the encoder and the wrapped decoder.

        Args:
            dim: Passed as ``dim`` to both the encoder and the decoder.
            ignore_index: Forwarded to ``AutoregressiveWrapper``.
            pad_value: Forwarded to ``AutoregressiveWrapper``.
            tie_token_embeds: When truthy, the encoder's ``token_emb``
                attribute is replaced by the decoder's, so both refer to the
                same object.
            no_projection: Passed as ``no_projection`` to both the encoder
                and the decoder.
            **kwargs: ``enc_*`` entries go to the encoder constructor and
                ``dec_*`` entries to the decoder constructor, prefix
                removed. Entries with neither prefix are discarded.

        Raises:
            AssertionError: If ``enc_dim`` or ``dec_dim`` is supplied.
        """
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        # The width is shared: it must come in through ``dim`` and may not be
        # set per side via ``enc_dim`` / ``dec_dim``.
        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection

        # These two decoder options are always overridden, whatever the
        # caller passed as ``dec_causal`` / ``dec_cross_attend``.
        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            enc.token_emb = dec.token_emb

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
        """Encode ``seq_in`` and generate from the decoder with gradients off.

        Args:
            seq_in: Input handed to the encoder.
            seq_out_start: First positional argument of ``self.dec.generate``.
            seq_len: Second positional argument of ``self.dec.generate``.
            **kwargs: ``enc_*`` entries go to the encoder call and ``dec_*``
                entries to ``self.dec.generate``, prefix removed. Entries
                with neither prefix are also passed to ``self.dec.generate``
                and win over a ``dec_*`` entry of the same stripped name. An
                ``enc_mask`` entry is reused as the decoder's
                ``context_mask`` unless ``dec_context_mask`` is given.

        Returns:
            Whatever ``self.dec.generate`` returns.
        """
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, return_encodings = True, **enc_kwargs)
        return self.dec.generate(seq_out_start, seq_len, context = encodings, **{**dec_kwargs, **kwargs})

    def forward(self, seq_in, seq_out, enc_mask = None, **kwargs):
        """Encode ``seq_in`` and run the wrapped decoder on ``seq_out``.

        Args:
            seq_in: Input handed to the encoder.
            seq_out: Input handed to the wrapped decoder.
            enc_mask: Passed to the encoder as ``mask`` and, unless
                ``dec_context_mask`` is supplied, to the decoder as
                ``context_mask``.
            **kwargs: ``enc_*`` entries go to the encoder call and ``dec_*``
                entries to the decoder call, prefix removed. Entries with
                neither prefix are not forwarded.

        Returns:
            Whatever the wrapped decoder returns.
        """
        enc_kwargs, dec_kwargs, _ = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, mask = enc_mask, return_encodings = True, **enc_kwargs)

        # Use setdefault instead of passing context_mask as its own keyword:
        # an explicit ``dec_context_mask`` would otherwise collide with it and
        # raise "got multiple values for keyword argument 'context_mask'".
        dec_kwargs.setdefault('context_mask', enc_mask)
        return self.dec(seq_out, context = encodings, **dec_kwargs)