DL-Art-School/codes/models/lucidrains/performer/performer_enc_dec.py
2022-01-09 22:10:07 -07:00

79 lines
2.9 KiB
Python

import re
import torch
from torch import nn
from performer_pytorch import PerformerLM
from autoregressive_wrapper import AutoregressiveWrapper
# Keyword-argument prefixes that route options to the encoder / decoder.
# The prefix is stripped before forwarding, e.g. `enc_foo=1` reaches the
# encoder as `foo=1` (see `group_by_key_prefix_and_remove_prefix`).
ENC_PREFIX, DEC_PREFIX = 'enc_', 'dec_'
def group_dict_by_key(cond, d):
    """Partition ``d`` into two dicts using a predicate evaluated on each key.

    Args:
        cond: callable taking a key; its result is interpreted for truthiness.
        d: mapping to split.

    Returns:
        A 2-tuple ``(matching, non_matching)``. Entries whose key makes ``cond``
        truthy land in ``matching``; all others land in ``non_matching``. The
        iteration order of ``d`` is preserved inside each dict.
    """
    matching, non_matching = {}, {}
    for k, v in d.items():
        bucket = matching if cond(k) else non_matching
        bucket[k] = v
    return matching, non_matching
def string_begins_with(prefix, str):
    """Return True if ``str`` starts with the literal text ``prefix``.

    The previous implementation interpolated ``prefix`` into a regular
    expression (``^{prefix}``) without escaping it. Regex metacharacters in the
    prefix were therefore interpreted: ``'a.'`` matched ``'abc'`` and ``'('``
    raised ``re.error``. ``str.startswith`` compares literally.

    Args:
        prefix: literal prefix to test for.
        str: the string to test. The parameter name shadows the builtin and is
            kept only for backward compatibility with keyword callers.

    Returns:
        bool
    """
    # `str` here is the argument (a string instance), not the builtin type.
    return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
    """Split ``d`` into ``(keys starting with prefix, all other keys)``.

    Keys are left unchanged; see ``group_by_key_prefix_and_remove_prefix`` for
    the variant that strips the prefix.
    """
    def has_prefix(key):
        return string_begins_with(prefix, key)

    return group_dict_by_key(has_prefix, d)
def group_by_key_prefix_and_remove_prefix(prefix, d):
    """Split ``d`` by key prefix and strip ``prefix`` from the matching keys.

    Args:
        prefix: key prefix to select on (e.g. ``'enc_'``).
        d: mapping of keyword arguments.

    Returns:
        ``(stripped, rest)``. ``stripped`` maps each key that starts with
        ``prefix`` (with the prefix removed) to its value. ``rest`` holds every
        other entry unchanged.

    Example: ``('enc_', {'enc_depth': 2, 'dim': 8})`` -> ``({'depth': 2}, {'dim': 8})``.
    """
    # Reuse the shared prefix-grouping helper instead of re-spelling its lambda.
    kwargs_with_prefix, kwargs = group_by_key_prefix(prefix, d)
    # Slice off exactly the leading prefix. Stripped keys cannot collide,
    # because all selected keys share the same prefix.
    kwargs_without_prefix = {
        key[len(prefix):]: value for key, value in kwargs_with_prefix.items()
    }
    return kwargs_without_prefix, kwargs
def extract_enc_dec_kwargs(kwargs):
    """Route prefixed kwargs to the encoder and decoder.

    Keys starting with ``ENC_PREFIX`` are peeled off first, then keys starting
    with ``DEC_PREFIX`` are peeled off the remainder. In both cases the prefix is
    stripped.

    Returns:
        ``(enc_kwargs, dec_kwargs, remaining_kwargs)``.
    """
    remaining = kwargs
    per_model = []
    for prefix in (ENC_PREFIX, DEC_PREFIX):
        stripped, remaining = group_by_key_prefix_and_remove_prefix(prefix, remaining)
        per_model.append(stripped)
    enc_kwargs, dec_kwargs = per_model
    return enc_kwargs, dec_kwargs, remaining
def extract_and_set_enc_dec_kwargs(kwargs):
    """Like ``extract_enc_dec_kwargs``, but also forward the encoder mask.

    If an encoder ``mask`` was supplied (as ``enc_mask``), the decoder receives
    the same object as ``context_mask``. An explicit ``dec_context_mask`` from
    the caller takes precedence.

    Returns:
        ``(enc_kwargs, dec_kwargs, remaining_kwargs)``.
    """
    enc_kwargs, dec_kwargs, rest = extract_enc_dec_kwargs(kwargs)
    if 'mask' in enc_kwargs and 'context_mask' not in dec_kwargs:
        dec_kwargs['context_mask'] = enc_kwargs['mask']
    return enc_kwargs, dec_kwargs, rest
class PerformerEncDec(nn.Module):
    """Encoder-decoder model built from two ``PerformerLM`` instances.

    The encoder is a plain ``PerformerLM``. The decoder is a ``PerformerLM``
    constructed with ``causal=True`` and ``cross_attend=True``, then wrapped in
    ``AutoregressiveWrapper`` (which provides ``forward`` and ``generate``).

    Constructor kwargs prefixed with ``enc_`` / ``dec_`` are forwarded (prefix
    stripped) to the encoder / decoder ``PerformerLM``. ``dim`` and
    ``no_projection`` are shared by both.
    """

    def __init__(
        self,
        dim,
        ignore_index = 0,
        pad_value = 0,
        tie_token_embeds = False,
        no_projection = False,
        **kwargs
    ):
        """
        Args:
            dim: model width, applied to both encoder and decoder.
            ignore_index: forwarded to ``AutoregressiveWrapper``.
            pad_value: forwarded to ``AutoregressiveWrapper``.
            tie_token_embeds: if True, the encoder's ``token_emb`` is replaced by
                the decoder's ``token_emb`` module, so they share weights.
            no_projection: forwarded to both ``PerformerLM`` constructors.
            **kwargs: ``enc_*`` / ``dec_*`` options for the respective model.

        Raises:
            AssertionError: if ``enc_dim`` or ``dec_dim`` is passed. The width is
                shared and must be given via ``dim``.
        """
        super().__init__()
        # NOTE(review): kwargs without an enc_/dec_ prefix are discarded here, so
        # a typo such as `depth=` instead of `enc_depth=` is silently ignored.
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, \
            'do not pass enc_dim / dec_dim; set the shared dimension of encoder and decoder with the `dim` keyword'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection
        # The decoder is autoregressive and attends to the encoder output.
        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            enc.token_emb = dec.token_emb

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
        """Encode ``seq_in`` and autoregressively sample from the decoder.

        Args:
            seq_in: encoder input, passed to ``self.enc``.
            seq_out_start: decoder prime/start sequence for ``self.dec.generate``.
            seq_len: passed through to ``self.dec.generate`` (presumably the
                number of tokens to sample; TODO confirm against AutoregressiveWrapper).
            **kwargs: ``enc_*`` options go to the encoder and ``dec_*`` options to
                the decoder. Un-prefixed options are also passed to
                ``self.dec.generate``. ``enc_mask`` is reused as the decoder's
                ``context_mask`` unless ``dec_context_mask`` is given.
        """
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, return_encodings = True, **enc_kwargs)
        return self.dec.generate(seq_out_start, seq_len, context = encodings, **{**dec_kwargs, **kwargs})

    def forward(self, seq_in, seq_out, enc_mask = None, **kwargs):
        """Encode ``seq_in`` and run the wrapped decoder on ``seq_out``.

        Args:
            seq_in: encoder input.
            seq_out: decoder target sequence, passed to ``AutoregressiveWrapper``.
            enc_mask: encoder mask. It is also used as the decoder's
                ``context_mask`` unless ``dec_context_mask`` is supplied.
            **kwargs: ``enc_*`` / ``dec_*`` options for encoder / decoder.
                NOTE(review): un-prefixed options are ignored here, whereas
                ``generate`` forwards them to the decoder.

        Returns:
            Whatever ``AutoregressiveWrapper.__call__`` returns.
        """
        enc_kwargs, dec_kwargs, _ = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, mask = enc_mask, return_encodings = True, **enc_kwargs)
        # Previously `context_mask = enc_mask` was passed alongside **dec_kwargs.
        # A caller-supplied `dec_context_mask` then raised
        # "got multiple values for keyword argument". An explicit decoder mask
        # now wins; otherwise the encoder mask is used, as before.
        dec_kwargs.setdefault('context_mask', enc_mask)
        return self.dec(seq_out, context = encodings, **dec_kwargs)