import torch
from torch import nn

from performer_pytorch import PerformerLM
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

ENC_PREFIX = 'enc_'
DEC_PREFIX = 'dec_'

# helpers for routing keyword arguments to either the encoder or the decoder

def group_dict_by_key(cond, d):
    # split a dict into (matching, non-matching) according to a predicate on keys
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

def string_begins_with(prefix, s):
    return s.startswith(prefix)

def group_by_key_prefix(prefix, d):
    return group_dict_by_key(lambda x: string_begins_with(prefix, x), d)

def group_by_key_prefix_and_remove_prefix(prefix, d):
    # pull out all keys starting with the prefix, then strip the prefix from them
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: string_begins_with(prefix, x), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

def extract_enc_dec_kwargs(kwargs):
    # partition kwargs into encoder kwargs ('enc_' prefixed), decoder kwargs ('dec_' prefixed), and the rest
    enc_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(ENC_PREFIX, kwargs)
    dec_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(DEC_PREFIX, kwargs)
    return enc_kwargs, dec_kwargs, kwargs

def extract_and_set_enc_dec_kwargs(kwargs):
    enc_kwargs, dec_kwargs, kwargs = extract_enc_dec_kwargs(kwargs)
    # if a mask was supplied for the encoder, reuse it as the decoder's cross-attention context mask
    if 'mask' in enc_kwargs:
        dec_kwargs.setdefault('context_mask', enc_kwargs['mask'])
    return enc_kwargs, dec_kwargs, kwargs

# main class

class PerformerEncDec(nn.Module):
    def __init__(
        self,
        dim,
        ignore_index = 0,
        pad_value = 0,
        tie_token_embeds = False,
        no_projection = False,
        **kwargs
    ):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'set the dimension once with the top-level `dim` keyword, not with `enc_dim` or `dec_dim`'
        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection

        # the decoder must be causal and cross-attend over the encoder output
        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            # share a single token embedding table between encoder and decoder
            enc.token_emb = dec.token_emb

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        # encode the source once, then sample autoregressively from the decoder
        encodings = self.enc(seq_in, return_encodings = True, **enc_kwargs)
        return self.dec.generate(seq_out_start, seq_len, context = encodings, **{**dec_kwargs, **kwargs})

    def forward(self, seq_in, seq_out, enc_mask = None, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, mask = enc_mask, return_encodings = True, **enc_kwargs)
        return self.dec(seq_out, context = encodings, context_mask = enc_mask, **dec_kwargs)
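
# A minimal usage sketch. The hyperparameters below (vocab size, depth, heads,
# sequence lengths) and the choice of start/eos token ids are illustrative
# assumptions, not values prescribed by this module. Any keyword accepted by
# PerformerLM can be routed to the encoder or decoder by prefixing it with
# 'enc_' or 'dec_', as the helpers above demonstrate.

if __name__ == '__main__':
    NUM_TOKENS = 20000
    SRC_SEQ_LEN = 1024
    TGT_SEQ_LEN = 1024

    enc_dec = PerformerEncDec(
        dim = 512,
        tie_token_embeds = True,   # requires encoder and decoder vocabularies to match
        enc_num_tokens = NUM_TOKENS,
        enc_depth = 6,
        enc_heads = 8,
        enc_max_seq_len = SRC_SEQ_LEN,
        dec_num_tokens = NUM_TOKENS,
        dec_depth = 6,
        dec_heads = 8,
        dec_max_seq_len = TGT_SEQ_LEN
    )

    src = torch.randint(0, NUM_TOKENS, (1, SRC_SEQ_LEN))
    tgt = torch.randint(0, NUM_TOKENS, (1, TGT_SEQ_LEN))
    src_mask = torch.ones_like(src).bool()

    # training: forward returns the autoregressive loss from AutoregressiveWrapper
    loss = enc_dec(src, tgt, enc_mask = src_mask)
    loss.backward()

    # inference: prime the decoder with a start token and sample out to seq_len
    start_tokens = torch.zeros((1, 1)).long()  # assumes token 0 is <bos>
    samples = enc_dec.generate(src, start_tokens, seq_len = TGT_SEQ_LEN, eos_token = 1)  # assumes token 1 is <eos>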