This commit is contained in:
mrq 2023-09-03 21:36:58 -05:00
parent c56ce033d9
commit 3a6bd50322

View File

@ -16,14 +16,17 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
from .retnet import RetNetDecoder, RetNetConfig from .retnet import RetNetDecoder, RetNetConfig
from .transformer import SinusoidalEmbedding, Block as TransformerBlock from .transformer import SinusoidalEmbedding, Block as TransformerBlock
from ..ext.interleaver import ( try:
CodebooksPatternProvider, from ..ext.interleaver import (
DelayedPatternProvider, CodebooksPatternProvider,
MusicLMPattern, DelayedPatternProvider,
ParallelPatternProvider, MusicLMPattern,
UnrolledPatternProvider, ParallelPatternProvider,
VALLEPattern, UnrolledPatternProvider,
) VALLEPattern,
)
except Exception as e:
pass
from ..config import cfg from ..config import cfg
@ -422,9 +425,8 @@ class Base(nn.Module):
if shift_targ_list: if shift_targ_list:
targ_list = [*targ_list] targ_list = [*targ_list]
for i in range(len(targ_list)): for i in range(len(targ_list)):
targ_list[i] = targ_list[i].roll(-self.n_resp_levels, dims=0) targ_list[i] = targ_list[i].roll(-1, dims=0)
for j in range(self.n_resp_levels): targ_list[i][-1] = self.stop_token
targ_list[i][-j-1] = self.stop_token
# create the new target sequence to compute the loss against # create the new target sequence to compute the loss against
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
@ -539,7 +541,7 @@ def example_usage():
'n_tokens': 1024, 'n_tokens': 1024,
'd_model': 1024, 'd_model': 1024,
'n_heads': 16, 'n_heads': 16,
'n_layers': 12, 'n_layers': 18,
} }
models = { "ar": Base(**kwargs).to(device) } models = { "ar": Base(**kwargs).to(device) }
@ -563,7 +565,7 @@ def example_usage():
qnt.to(device), qnt.to(device),
] ]
def sample( filename, steps=400 ): def sample( filename, steps=450 * 4 ):
AR = None AR = None
engines.eval() engines.eval()