haha
This commit is contained in:
parent
c56ce033d9
commit
3a6bd50322
|
@ -16,14 +16,17 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
|
|||
from .retnet import RetNetDecoder, RetNetConfig
|
||||
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
|
||||
|
||||
from ..ext.interleaver import (
|
||||
CodebooksPatternProvider,
|
||||
DelayedPatternProvider,
|
||||
MusicLMPattern,
|
||||
ParallelPatternProvider,
|
||||
UnrolledPatternProvider,
|
||||
VALLEPattern,
|
||||
)
|
||||
try:
|
||||
from ..ext.interleaver import (
|
||||
CodebooksPatternProvider,
|
||||
DelayedPatternProvider,
|
||||
MusicLMPattern,
|
||||
ParallelPatternProvider,
|
||||
UnrolledPatternProvider,
|
||||
VALLEPattern,
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
from ..config import cfg
|
||||
|
||||
|
@ -422,9 +425,8 @@ class Base(nn.Module):
|
|||
if shift_targ_list:
|
||||
targ_list = [*targ_list]
|
||||
for i in range(len(targ_list)):
|
||||
targ_list[i] = targ_list[i].roll(-self.n_resp_levels, dims=0)
|
||||
for j in range(self.n_resp_levels):
|
||||
targ_list[i][-j-1] = self.stop_token
|
||||
targ_list[i] = targ_list[i].roll(-1, dims=0)
|
||||
targ_list[i][-1] = self.stop_token
|
||||
|
||||
# create the new target sequence to compute the loss against
|
||||
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
|
||||
|
@ -539,7 +541,7 @@ def example_usage():
|
|||
'n_tokens': 1024,
|
||||
'd_model': 1024,
|
||||
'n_heads': 16,
|
||||
'n_layers': 12,
|
||||
'n_layers': 18,
|
||||
}
|
||||
models = { "ar": Base(**kwargs).to(device) }
|
||||
|
||||
|
@ -563,7 +565,7 @@ def example_usage():
|
|||
qnt.to(device),
|
||||
]
|
||||
|
||||
def sample( filename, steps=400 ):
|
||||
def sample( filename, steps=450 * 4 ):
|
||||
AR = None
|
||||
|
||||
engines.eval()
|
||||
|
|
Loading…
Reference in New Issue
Block a user