68 lines
2.4 KiB
Python
Executable File
68 lines
2.4 KiB
Python
Executable File
"""
|
|
# https://github.com/facebookresearch/fairseq/blob/main/fairseq/incremental_decoding_utils.py
|
|
# Copied directly because merely having fairseq installed WILL break logging.
|
|
"""
|
|
|
|
import uuid
|
|
from typing import Dict, Optional
|
|
|
|
from torch import Tensor
|
|
|
|
class FairseqIncrementalState:
    """Mixin that namespaces entries inside a shared incremental-state dict.

    On construction each instance receives a random UUID string. Every key
    read or written through the helpers below is prefixed with that UUID, so
    several instances can store entries in the same ``incremental_state``
    dictionary under the same short ``key`` without colliding.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative init: forward all arguments along the MRO first, then
        # tag this instance with its own identifier.
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Assign a fresh random UUID string as this instance's key prefix."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return ``key`` prefixed by this instance's identifier and a dot."""
        prefix = self._incremental_state_id
        return "{}.{}".format(prefix, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Look up the entry stored for this instance under ``key``.

        Returns ``None`` when ``incremental_state`` is ``None`` or when it
        holds no entry for this instance's namespaced key; otherwise returns
        the stored value itself (not a copy).
        """
        namespaced = self._get_full_incremental_state_key(key)
        if incremental_state is not None and namespaced in incremental_state:
            return incremental_state[namespaced]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Store ``value`` for this instance under ``key``.

        The dictionary is modified in place and the same object is returned.
        When ``incremental_state`` is ``None`` nothing is stored and ``None``
        is returned.
        """
        if incremental_state is None:
            return None
        namespaced = self._get_full_incremental_state_key(key)
        incremental_state[namespaced] = value
        return incremental_state
|
|
|
|
|
|
def with_incremental_state(cls):
    """Class decorator that makes ``FairseqIncrementalState`` the first base.

    Rewrites ``cls.__bases__`` in place so that ``FairseqIncrementalState``
    comes first, followed by the class's existing bases in their original
    order (any existing occurrence of ``FairseqIncrementalState`` is dropped
    so it is listed only once). Returns the same class object.
    """
    other_bases = [
        base for base in cls.__bases__ if base != FairseqIncrementalState
    ]
    cls.__bases__ = (FairseqIncrementalState, *other_bases)
    return cls
|
|
|
|
|
|
from torchscale.architecture.config import RetNetConfig
|
|
from torchscale.architecture.retnet import RetNetDecoder as Decoder
|
|
|
|
@with_incremental_state
class RetNetDecoder(Decoder):
    """``Decoder`` subclass decorated with ``with_incremental_state``.

    Adds ``max_positions`` and ``reorder_incremental_state`` on top of the
    parent decoder; ``forward`` is a plain pass-through.
    """

    def forward(self, src_tokens, **kwargs):
        """Delegate to the parent ``forward`` with the same arguments."""
        return super().forward(src_tokens, **kwargs)

    def max_positions(self):
        """Return ``self.args.max_token_positions``."""
        return self.args.max_token_positions

    def reorder_incremental_state(self, incremental_state, new_order):
        """Re-index every cached tensor along dimension 0, in place.

        Args:
            incremental_state: Two-level mapping ``{module: {key: tensor}}``.
                May be ``None``, in which case nothing is done.
            new_order: Index tensor passed to ``Tensor.index_select``.
                NOTE(review): presumably dimension 0 of each cached tensor is
                the batch/beam dimension -- confirm against the caller.

        Entries whose value is ``None`` are left untouched instead of raising
        ``AttributeError``.
        """
        if incremental_state is None:
            # Nothing cached yet, so there is nothing to reorder.
            return

        for module_state in incremental_state.values():
            # Only values of existing keys are replaced, so mutating the
            # inner dict while iterating over it is safe.
            for key, cached in module_state.items():
                if cached is None:
                    continue
                module_state[key] = cached.index_select(0, new_order)