This commit is contained in:
mrq 2023-09-03 21:36:58 -05:00
parent c56ce033d9
commit 3a6bd50322

View File

@ -16,14 +16,17 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
from .retnet import RetNetDecoder, RetNetConfig
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
from ..ext.interleaver import (
CodebooksPatternProvider,
DelayedPatternProvider,
MusicLMPattern,
ParallelPatternProvider,
UnrolledPatternProvider,
VALLEPattern,
)
try:
from ..ext.interleaver import (
CodebooksPatternProvider,
DelayedPatternProvider,
MusicLMPattern,
ParallelPatternProvider,
UnrolledPatternProvider,
VALLEPattern,
)
except Exception as e:
pass
from ..config import cfg
@ -422,9 +425,8 @@ class Base(nn.Module):
if shift_targ_list:
targ_list = [*targ_list]
for i in range(len(targ_list)):
targ_list[i] = targ_list[i].roll(-self.n_resp_levels, dims=0)
for j in range(self.n_resp_levels):
targ_list[i][-j-1] = self.stop_token
targ_list[i] = targ_list[i].roll(-1, dims=0)
targ_list[i][-1] = self.stop_token
# create the new target sequence to compute the loss against
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
@ -539,7 +541,7 @@ def example_usage():
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 12,
'n_layers': 18,
}
models = { "ar": Base(**kwargs).to(device) }
@ -563,7 +565,7 @@ def example_usage():
qnt.to(device),
]
def sample( filename, steps=400 ):
def sample( filename, steps=450 * 4 ):
AR = None
engines.eval()