fixes

This commit is contained in:
parent 186b93a77e
commit c93d5863fd
@@ -44,7 +44,7 @@ def fold_inputs(
	text_tokens = 256,
	audio_tokens = 1024,
-	audio_rvq_levels = cfg.model.prom_levels
+	audio_rvq_levels = cfg.model.max_levels
):
	def _create_mask(l, device):
		seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
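The `_create_mask` helper visible in the context lines builds a padding mask from per-sample lengths. A runnable sketch of how such a helper typically completes in this codebase's lineage (the `stop` line is filled in from that lineage, so treat it as an assumption rather than the commit's code):

import torch

# Length-to-mask helper: row b is True for positions < l[b], False over padding.
def _create_mask(l, device):
    seq = torch.arange(max(l), device=device).unsqueeze(0)  # (1, t)
    stop = torch.tensor(l, device=device).unsqueeze(1)      # (b, 1)
    return (seq < stop).float()                             # (b, t)

print(_create_mask([3, 5], "cpu"))
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1.]])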
@@ -107,7 +107,7 @@ def unfold_outputs(
	text_tokens = 256,
	audio_tokens = 1024,
-	audio_rvq_levels = cfg.model.prom_levels
+	audio_rvq_levels = cfg.model.max_levels
):
	device = output_ids.device
	batch_size = output_ids.shape[0]
@@ -139,7 +139,7 @@ def unfold_outputs(
				bins[rvq].append( prom_list[i][pos] )
			nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
			bins = bins[:nearest]
-			prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+			prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)

			resp_len = len(resp_list[i])
@@ -152,9 +152,9 @@ def unfold_outputs(
				bins[rvq].append( resp_list[i][pos] )
			nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
			bins = bins[:nearest]
-			resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+			resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)

-		text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64)
+		text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)

	return dict(
		text_list=text_list,
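The `device=` additions above matter because `torch.Tensor(...)` always allocates a CPU tensor, regardless of where `output_ids` lives; the old `.to(dtype=torch.int64)` kept the unfolded codes on the CPU, which would break later interaction with CUDA tensors. A minimal sketch of the fix (values are illustrative):

import torch

# torch.Tensor(list) builds a CPU float tensor; passing device= alongside
# dtype= moves the unfolded codes onto the same device as output_ids.
output_ids = torch.zeros(4, 8, dtype=torch.int64)  # stand-in for model output
device = output_ids.device

bins = [[1, 2, 3], [4, 5, 6]]                      # one row per RVQ level
codes = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
print(codes.shape)  # (3 timesteps, 2 levels)

# The `nearest` truncation in the hunks drops trailing codes so the flat
# stream splits evenly into audio_rvq_levels rows before the transpose.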
@@ -963,6 +963,7 @@ class RetNetModel(RetNetPreTrainedModel):
					retention_mask,
					forward_impl,
					past_key_value,
+					use_reentrant=True,
				)
			else:
				layer_outputs = layer(
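This hunk pins the activation-checkpointing behavior explicitly. Recent PyTorch releases warn when `use_reentrant` is not passed to `torch.utils.checkpoint.checkpoint`, and the default is slated to change; the added argument keeps the legacy reentrant variant. A self-contained sketch of the call (`layer_fn` is an illustrative stand-in for the RetNet decoder layer):

import torch
from torch.utils.checkpoint import checkpoint

# Pin the reentrant checkpoint variant explicitly, as the hunk above does.
def layer_fn(x):
    return torch.nn.functional.gelu(x)

x = torch.randn(2, 4, requires_grad=True)
y = checkpoint(layer_fn, x, use_reentrant=True)
y.sum().backward()
print(x.grad.shape)  # gradients flow through the checkpointed segment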
@@ -1,10 +1,9 @@
-from .ar_nar import AR_NAR
-from .experimental import Model as Experimental

def get_model(cfg, training=True):
	name = cfg.name

	if not cfg.experimental:
+		from .ar_nar import AR_NAR
		model = AR_NAR(
			n_tokens=cfg.tokens,
			d_model=cfg.dim,
@@ -21,6 +20,7 @@ def get_model(cfg, training=True):
		)
		model._cfg = cfg
	else:
+		from .experimental import Model as Experimental
		model = Experimental(
			d_model=cfg.dim,
			n_layers=cfg.layers,
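The two hunks above move the model imports from module scope into the branch that actually uses them, presumably so a broken optional dependency on one path cannot take down the other. The resulting pattern, condensed (names follow the diff; constructor bodies are elided in the repo):

# Lazy-import pattern adopted above, in miniature.
def get_model(cfg, training=True):
    if not cfg.experimental:
        from .ar_nar import AR_NAR            # imported only on this path
        model = AR_NAR(n_tokens=cfg.tokens, d_model=cfg.dim)
    else:
        from .experimental import Model as Experimental
        model = Experimental(d_model=cfg.dim, n_layers=cfg.layers)
    return model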
@@ -386,7 +386,7 @@ def example_usage():
	"""

	model = AR_NAR(**kwargs).to(device)
-	steps = 200
+	steps = 50

	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@@ -448,7 +448,7 @@ def example_usage():

	torch.save( {
		'module': model.state_dict()
-	}, "./data/test.pth" )
+	}, f"./data/{cfg.model.arch_type}.pth" )

	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@@ -459,16 +459,11 @@ def example_usage():
		engine.eval()
		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )

-		if cfg.audio_backend != "dac":
-			for i, o in enumerate(resps_list):
-				_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
-
		resps_list = [r.unsqueeze(-1) for r in resps_list]
		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )

		for i, o in enumerate(resps_list):
-			_ = decode_to_file(o, f"data/ar+nar.{i}.{name}.wav", device=device)
+			_ = decode_to_file(o, f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)

		unload_model()
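For context on the `unsqueeze(-1)` step kept above: the first `engine(...)` call is the AR pass and yields one code sequence per sample; adding a trailing quantizer-level axis lets the second (NAR) pass fill in the remaining RVQ levels. A toy illustration (shapes assumed):

import torch

# AR pass output: one (t,) code sequence per sample. The NAR pass expects
# (t, levels), so unsqueeze(-1) adds the level axis with one level filled.
ar_out = [torch.randint(0, 1024, (75,)) for _ in range(2)]
resps_list = [r.unsqueeze(-1) for r in ar_out]
print(resps_list[0].shape)  # torch.Size([75, 1])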
@@ -484,7 +479,7 @@ def example_usage():

	torch.save( {
		'module': model.state_dict()
-	}, "./data/test.pth" )
+	}, f"./data/{cfg.model.arch_type}.pth" )

	sample("init", 5)
	train()
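Both `torch.save` hunks rename the checkpoint from a fixed `test.pth` to one keyed by `cfg.model.arch_type`, so runs against different backbones stop overwriting each other. Reloading just has to use the same key (a hedged sketch; the arch string is assumed):

import os
import torch

# Illustrative reload matching the new naming scheme.
arch_type = "retnet"  # assumed; mirrors cfg.model.arch_type
path = f"./data/{arch_type}.pth"
if os.path.exists(path):
    state = torch.load(path, map_location="cpu")
    # model.load_state_dict(state["module"])  # `model` built as in the example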
@@ -31,6 +31,15 @@ except Exception as e:
	print("Error importing `llama` arch:", e)
	pass

+try:
+	from .retnet_hf import RetNetConfig
+	from ..ext.retnet_hf.modeling_retnet import RetNetForCausalLM
+
+	AVAILABLE_ARCHES.append("retnet")
+except Exception as e:
+	print("Error importing `retnet` arch:", e)
+	pass
+
try:
	from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
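The file builds an `AVAILABLE_ARCHES` registry by attempting each backbone import and recording only the ones that succeed; the hunk above adds `retnet` to that list. The pattern in isolation (the stand-in module name below is illustrative and deliberately fails here):

AVAILABLE_ARCHES = []

try:
    # in the repo: from ..ext.retnet_hf.modeling_retnet import RetNetForCausalLM
    import retnet_hf_modeling  # illustrative stand-in module name
    AVAILABLE_ARCHES.append("retnet")
except Exception as e:
    print("Error importing `retnet` arch:", e)

print(AVAILABLE_ARCHES)  # only arches whose imports succeeded are selectable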
@@ -75,6 +84,8 @@ if SELECTED_ARCH == "mamba":
	LlmArchClass = MambaLMHeadModel
elif SELECTED_ARCH == "llama":
	LlmArchClass = LlamaForCausalLM
+elif SELECTED_ARCH == "retnet":
+	LlmArchClass = RetNetForCausalLM
else:
	raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
@@ -92,18 +103,19 @@ class Model(LlmArchClass):

		hf_attention = config.attention if config is not None else None
		gradient_checkpointing = config.gradient_checkpointing if config is not None else True
+		vocab_size = 256 + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1

		if SELECTED_ARCH == "llama":
			super().__init__(config=LlamaConfig(
-				vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
+				vocab_size=vocab_size,
				hidden_size=d_model,
-				max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.prom_levels * 60, # max-length of 60 seconds
+				max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.max_levels * 60, # max-length of 60 seconds
				intermediate_size=d_model*4,
				num_hidden_layers=n_layers,
				num_attention_heads=n_heads,
				attention_dropout=p_dropout,
				num_key_value_heads=n_heads,
-				sliding_window=cfg.dataset.frames_per_second * cfg.model.prom_levels * 12,
+				sliding_window=cfg.dataset.frames_per_second * cfg.model.max_levels * 12,
				hidden_act="gelu",
				is_encoder_decoder=False,
				is_decoder=True,
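The hoisted `vocab_size` expression packs every token type into one flat vocabulary: 256 text tokens, one bank of 1024 audio codes per RVQ level for the prompt, another per-level bank for the response, and a single stop token. Worked out with an assumed level count:

# Vocabulary layout implied by the hoisted expression (max_levels assumed).
text_tokens = 256
audio_tokens = 1024
max_levels = 8  # assumed value of cfg.model.max_levels

vocab_size = (
    text_tokens
    + (audio_tokens * max_levels)  # prompt code banks, one per level
    + (audio_tokens * max_levels)  # response code banks, one per level
    + 1                            # stop token
)
print(vocab_size)  # 256 + 8192 + 8192 + 1 = 16641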
@@ -114,9 +126,31 @@ class Model(LlmArchClass):
			self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
				use_reentrant=False
			))
+		elif SELECTED_ARCH == "retnet":
+			super().__init__(config=RetNetConfig(
+				vocab_size=vocab_size,
+				decoder_embed_dim=d_model,
+				decoder_value_embed_dim=d_model * 2,
+				decoder_retention_heads=n_heads,
+				decoder_ffn_embed_dim=d_model * 4,
+				decoder_layers=n_layers,
+				dropout=p_dropout,
+				checkpoint_activations=gradient_checkpointing,
+				activation_fn="gelu",
+				use_layernorm=False,
+				use_biases=False,
+				use_glu=True,
+
+				#chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
+				#recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
+				#no_output_layer=True,
+				#rotary_embedding_base=self.rotary_embedding_base, # 10000
+
+				decoder_normalize_before=True,
+			))
		elif SELECTED_ARCH == "mamba":
			super().__init__(config=MambaConfig(
-				vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
+				vocab_size=vocab_size,
				d_model=d_model,
				n_layer=n_layers*2,
				#ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
@@ -132,15 +166,15 @@ class Model(LlmArchClass):
	):
		output = super().forward(*args, **kwargs)

-		if SELECTED_ARCH == "llama":
+		if SELECTED_ARCH in ["llama", "retnet"]:
			if output.loss is not None:
				self.loss = dict(
					nll = output.loss,
				)
		elif SELECTED_ARCH == "mamba":
			if "labels" in kwargs:
-				logits = output.logits
				labels = kwargs.pop("labels")
+				logits = output.logits

				# Shift so that tokens < n predict n
				shift_logits = logits[..., :-1, :].contiguous()
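llama and retnet return a ready-made `output.loss`, while the mamba branch derives the shifted language-model loss from raw logits. The standard shift-by-one computation it continues with, as a self-contained sketch (shapes are illustrative):

import torch
import torch.nn.functional as F

# Shift so that tokens < n predict n, then flatten for cross entropy.
logits = torch.randn(2, 10, 100)           # (batch, seq, vocab)
labels = torch.randint(0, 100, (2, 10))    # (batch, seq)

shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(shift_logits.flatten(0, 1), shift_labels.flatten())
print(loss.item())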
@@ -183,7 +217,7 @@ def example_usage():

	def _load_quants(path) -> Tensor:
		qnt = np.load(path, allow_pickle=True)[()]
-		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
+		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.max_levels, :].t().to(torch.int16)

	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
@@ -278,7 +312,7 @@ def example_usage():
	print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

	@torch.inference_mode()
-	def sample( name, steps=cfg.model.prom_levels*cfg.dataset.frames_per_second*60 ):
+	def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*60 ):
		engine.eval()
		if SELECTED_ARCH == "mamba":
			output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3)
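The default step budget above is the folded-stream length of a 60-second clip: one token per RVQ level per frame. With assumed values for the two config fields:

# Generation budget arithmetic (both values assumed for illustration).
max_levels = 8          # cfg.model.max_levels
frames_per_second = 75  # cfg.dataset.frames_per_second, e.g. EnCodec at 24 kHz

steps = max_levels * frames_per_second * 60
print(steps)  # 36000 folded tokens for a 60-second clip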
@@ -1,6 +1,6 @@
# https://github.com/syncdoth/RetNet/
from ..ext.retnet_hf.configuration_retnet import RetNetConfig
-from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder
+from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder, RetNetForCausalLM

# things we're overriding or required to override
from ..ext.retnet_hf.modeling_retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, split_heads, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
@@ -32,77 +32,74 @@ def FeedForwardNetwork_init(

FeedForwardNetwork.__init__ = FeedForwardNetwork_init

# removes embed_tokens
def RetNetModel_init(
	self,
	config: RetNetConfig,
	embed_tokens: torch.nn.Embedding = None,
	tensor_parallel: bool = False,
):
	super(RetNetDecoder, self).__init__(config)
	self.config = config

	self.dropout_module = torch.nn.Dropout(config.dropout)

	self.embed_dim = config.decoder_embed_dim
	self.embed_scale = (
		1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
	)

-	"""
-	if embed_tokens is None:
-		embed_tokens = torch.nn.Embedding(
-			config.vocab_size, config.decoder_embed_dim, config.pad_token_id
-		)
-	"""
-	self.embed_tokens = None
+	if embed_tokens is None and config.vocab_size:
+		embed_tokens = torch.nn.Embedding(
+			config.vocab_size, config.decoder_embed_dim, config.pad_token_id
+		)
+	self.embed_tokens = embed_tokens

	if config.layernorm_embedding:
		self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
	else:
		self.layernorm_embedding = None

	self.layers = torch.nn.ModuleList([])

	for i in range(config.decoder_layers):
		self.layers.append(
			RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel)
		)

	self.decoder_layers = len(self.layers)

	if config.decoder_normalize_before:
		self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
	else:
		self.layer_norm = None

	self.retnet_rel_pos = RetNetRelPos(config)
	self.recurrent_chunk_size = config.recurrent_chunk_size

	if config.deepnorm:
		init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
		for name, p in self.named_parameters():
			if (
				"fc1" in name
				or "fc2" in name
				or "out_proj" in name
				or "v_proj" in name
			):
				p.data.div_(init_scale)

	if config.subln and not config.use_glu:
		init_scale = math.sqrt(math.log(config.decoder_layers * 2))
		for name, p in self.named_parameters():
			if (
				"fc1" in name
				or "fc2" in name
				or "out_proj" in name
				or "v_proj" in name
			):
				p.data.mul_(init_scale)

	self.gradient_checkpointing = True
	self.post_init()

RetNetDecoder.__init__ = RetNetModel_init
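Both of these shims patch the vendored RetNet implementation by reassigning `__init__` on the imported class rather than forking the file. The substantive change above swaps the old stub, which nulled `embed_tokens` (inputs previously arrived pre-embedded), for an init that builds the embedding table, which the newly imported `RetNetForCausalLM` path apparently needs for raw token-id inputs; the diff viewer rendered the rest of the body twice, so the duplicate copy is collapsed here. The patching technique in miniature (toy classes, not the repo's):

# Monkey-patching pattern used by both retnet shims.
class ThirdParty:
    def __init__(self):
        self.embed_tokens = "original embedding"

def patched_init(self):
    # the old shim set this to None; the new one constructs a real embedding
    self.embed_tokens = "replacement embedding"

ThirdParty.__init__ = patched_init
print(ThirdParty().embed_tokens)  # replacement embedding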
@@ -36,83 +36,81 @@ FeedForwardNetwork.__init__ = FeedForwardNetwork_init

# removes embed_tokens
def RetNetModel_init(
	self, config, embed_tokens=None, output_projection=None, **kwargs
):
	super(RetNetDecoder, self).__init__(**kwargs)
	self.config = config

	self.dropout_module = torch.nn.Dropout(config.dropout)

	self.embed_dim = config.decoder_embed_dim
	self.embed_scale = (
		1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
	)

-	"""
-	if embed_tokens is None:
-		embed_tokens = torch.nn.Embedding(
-			config.vocab_size, config.decoder_embed_dim, config.pad_token_id
-		)
-	"""
-	self.embed_tokens = None
+	if embed_tokens is None and config.vocab_size:
+		embed_tokens = torch.nn.Embedding(
+			config.vocab_size, config.decoder_embed_dim, config.pad_token_id
+		)
+	self.embed_tokens = embed_tokens

	if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
		self.output_projection = self.build_output_projection(config)
	else:
		self.output_projection = output_projection

	if config.layernorm_embedding:
		self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
	else:
		self.layernorm_embedding = None

	self.layers = torch.nn.ModuleList([])

	for i in range(config.decoder_layers):
		layer = self.build_decoder_layer(
			config,
			depth=i,
		)
		"""
		if config.checkpoint_activations:
			layer = checkpoint_wrapper(layer)
		"""
		self.layers.append(layer)

	self.num_layers = len(self.layers)

	if config.decoder_normalize_before:
		self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
	else:
		self.layer_norm = None

	self.retnet_rel_pos = RetNetRelPos(config)
	self.chunkwise_recurrent = config.chunkwise_recurrent
	self.recurrent_chunk_size = config.recurrent_chunk_size

	if config.deepnorm:
		init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
		for name, p in self.named_parameters():
			if (
				"fc1" in name
				or "fc2" in name
				or "out_proj" in name
				or "v_proj" in name
			):
				p.data.div_(init_scale)

	if config.subln and not config.use_glu:
		init_scale = math.sqrt(math.log(config.decoder_layers * 2))
		for name, p in self.named_parameters():
			if (
				"fc1" in name
				or "fc2" in name
				or "out_proj" in name
				or "v_proj" in name
			):
				p.data.mul_(init_scale)

	self.gradient_checkpointing = True

RetNetDecoder.__init__ = RetNetModel_init
@@ -109,9 +109,9 @@ def run_eval(engines, eval_name, dl):
		if engine.hyper_config.experimental:
			input_ids, attention_mask = fold_inputs(
				text_list=batch["text"],
-				proms_list=batch["proms"],
+				prom_list=batch["proms"],
			)
-			output = engine.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
+			output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
			resps_list = unfold_outputs( output )["resp_list"]
		else:
			resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
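Two small but load-bearing fixes here: the keyword becomes `prom_list`, consistent with the singular `prom_list`/`resp_list` naming visible in the `unfold_outputs` hunks, and the generate call goes through `.module`, since engine wrappers in the DeepSpeed style expose the wrapped network as `.module`, not `.model`. A toy illustration of the attribute fix (classes are illustrative):

# Engine wrappers (DeepSpeed-style) expose the wrapped network as `.module`.
class Engine:
    def __init__(self, module):
        self.module = module  # no `.model` attribute exists

class ToyModel:
    def generate(self, **kwargs):
        return [3]  # pretend eos

engine = Engine(ToyModel())
print(engine.module.generate(max_length=8))  # works; engine.model would raise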