vall-e/vall_e/models/retnet.py

68 lines
2.4 KiB
Python
Raw Normal View History

"""
# https://github.com/facebookresearch/fairseq/blob/main/fairseq/incremental_decoding_utils.py
# Copied directly because merely having fairseq installed breaks logging.
"""
import uuid
from typing import Dict, Optional
from torch import Tensor
class FairseqIncrementalState:
    """Mixin that gives each instance its own namespace in a shared state dict.

    Every instance is assigned a random UUID string. Keys passed to
    ``get_incremental_state`` / ``set_incremental_state`` are prefixed with
    that UUID, so several instances can store entries under the same ``key``
    in one ``incremental_state`` dict without overwriting each other.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments up the MRO, then assign the instance id."""
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Assign a fresh random UUID string as this instance's state id."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return ``key`` prefixed with this instance's id and a dot."""
        return f"{self._incremental_state_id}.{key}"

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module.

        Args:
            incremental_state: the shared state dict, or ``None``.
            key: the un-prefixed name of the entry to look up.

        Returns:
            The entry stored under this instance's prefixed key, or ``None``
            when ``incremental_state`` is ``None`` or has no such entry.
        """
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None:
            return None
        # One lookup instead of a membership test followed by indexing.
        return incremental_state.get(full_key)

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module.

        Args:
            incremental_state: the shared state dict, or ``None``.
            key: the un-prefixed name of the entry to store.
            value: the entry to store under this instance's prefixed key.

        Returns:
            The same ``incremental_state`` object that was passed in
            (mutated in place), or ``None`` if ``None`` was passed.
        """
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            incremental_state[full_key] = value
        return incremental_state
def with_incremental_state(cls):
    """Class decorator that puts FairseqIncrementalState first in ``cls.__bases__``.

    The existing bases are kept in their original order; any existing
    occurrence of FairseqIncrementalState is dropped so it appears only
    once, at the front. The class is modified in place and returned.
    """
    other_bases = [base for base in cls.__bases__ if base != FairseqIncrementalState]
    cls.__bases__ = (FairseqIncrementalState, *other_bases)
    return cls
2023-08-02 21:53:35 +00:00
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder as Decoder
@with_incremental_state
class RetNetDecoder(Decoder):
    """torchscale ``RetNetDecoder`` extended with incremental-state helpers.

    The ``with_incremental_state`` decorator prepends
    ``FairseqIncrementalState`` to this class's bases.
    """

    def forward(self, src_tokens, **kwargs):
        """Delegate unchanged to the parent decoder's ``forward``."""
        return super().forward(src_tokens, **kwargs)

    def max_positions(self):
        """Return ``self.args.max_token_positions``."""
        return self.args.max_token_positions

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder every cached tensor in place along dimension 0.

        Args:
            incremental_state: mapping whose values are themselves mappings
                from a key to a cached tensor (or ``None``).
            new_order: index tensor handed to ``index_select``; presumably
                the new batch/beam ordering -- TODO confirm with callers.

        Returns:
            None. ``incremental_state`` is mutated in place.
        """
        for module_state in incremental_state.values():
            for key, cached in module_state.items():
                # Cached values are typed Optional[Tensor] elsewhere in this
                # file; an unset (None) entry has nothing to reorder and
                # would otherwise raise AttributeError on index_select.
                if cached is None:
                    continue
                # Rebinding an existing key while iterating is safe: the
                # dict's size does not change.
                module_state[key] = cached.index_select(0, new_order)