haha
This commit is contained in:
parent
c56ce033d9
commit
3a6bd50322
|
@ -16,14 +16,17 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
|
||||||
from .retnet import RetNetDecoder, RetNetConfig
|
from .retnet import RetNetDecoder, RetNetConfig
|
||||||
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
|
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
|
||||||
|
|
||||||
from ..ext.interleaver import (
|
try:
|
||||||
CodebooksPatternProvider,
|
from ..ext.interleaver import (
|
||||||
DelayedPatternProvider,
|
CodebooksPatternProvider,
|
||||||
MusicLMPattern,
|
DelayedPatternProvider,
|
||||||
ParallelPatternProvider,
|
MusicLMPattern,
|
||||||
UnrolledPatternProvider,
|
ParallelPatternProvider,
|
||||||
VALLEPattern,
|
UnrolledPatternProvider,
|
||||||
)
|
VALLEPattern,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
from ..config import cfg
|
from ..config import cfg
|
||||||
|
|
||||||
|
@ -422,9 +425,8 @@ class Base(nn.Module):
|
||||||
if shift_targ_list:
|
if shift_targ_list:
|
||||||
targ_list = [*targ_list]
|
targ_list = [*targ_list]
|
||||||
for i in range(len(targ_list)):
|
for i in range(len(targ_list)):
|
||||||
targ_list[i] = targ_list[i].roll(-self.n_resp_levels, dims=0)
|
targ_list[i] = targ_list[i].roll(-1, dims=0)
|
||||||
for j in range(self.n_resp_levels):
|
targ_list[i][-1] = self.stop_token
|
||||||
targ_list[i][-j-1] = self.stop_token
|
|
||||||
|
|
||||||
# create the new target sequence to compute the loss against
|
# create the new target sequence to compute the loss against
|
||||||
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
|
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
|
||||||
|
@ -539,7 +541,7 @@ def example_usage():
|
||||||
'n_tokens': 1024,
|
'n_tokens': 1024,
|
||||||
'd_model': 1024,
|
'd_model': 1024,
|
||||||
'n_heads': 16,
|
'n_heads': 16,
|
||||||
'n_layers': 12,
|
'n_layers': 18,
|
||||||
}
|
}
|
||||||
models = { "ar": Base(**kwargs).to(device) }
|
models = { "ar": Base(**kwargs).to(device) }
|
||||||
|
|
||||||
|
@ -563,7 +565,7 @@ def example_usage():
|
||||||
qnt.to(device),
|
qnt.to(device),
|
||||||
]
|
]
|
||||||
|
|
||||||
def sample( filename, steps=400 ):
|
def sample( filename, steps=450 * 4 ):
|
||||||
AR = None
|
AR = None
|
||||||
|
|
||||||
engines.eval()
|
engines.eval()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user