ughghghhhh
commit 2a1794c084
parent ed373957e2
@@ -58,9 +58,11 @@ def fold_inputs(
         stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
         return (seq < stop).float() # (b t)

-    def list_to_tensor(x_list: list[Tensor]):
+    def list_to_tensor(x_list: list[Tensor], mask=True):
         l = list(map(len, x_list))
         x = pad_sequence(x_list).t()
+        if not mask:
+            return x

         m = _create_mask(l, x_list[0].device)
         m = m.to(x)
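A self-contained sketch (not the repo's exact helper) of what the new mask flag amounts to: with mask=True the padded batch comes back together with its attention mask, with mask=False only the padded ids are returned, which is what the position-id path later in this diff relies on.

import torch
from torch.nn.utils.rnn import pad_sequence

x_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

x = pad_sequence(x_list).t()                      # (b t), zero-padded
lengths = torch.tensor([len(t) for t in x_list])
m = (torch.arange(x.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)).float()  # (b t)

# mask=True  -> return both x and m
# mask=False -> return only x (position ids don't need a separate mask)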
@@ -68,7 +70,7 @@ def fold_inputs(

     def process_prom_or_task(i, prom):
         if prom is None:
-            return
+            return 0

         if isinstance(prom, str):
             task = get_task_symmap()[f'<{input}>']
@@ -76,7 +78,8 @@ def fold_inputs(

             input_ids[i].append( seq )
             input_ids[i].append( sep )
-            return
+
+            return seq.shape[0] + 1

         # deinterleaved
         if quant_levels is not None:
@@ -99,6 +102,11 @@ def fold_inputs(
             input_ids[i].append( seq )
             input_ids[i].append( sep )

+            return seq.shape[0] + 1
+
+    def generate_position_ids( length, sep=True ):
+        return [ i for i in range( length + (1 if sep else 0) ) ]
+
     """
     if quant_levels is not None:
         resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
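generate_position_ids gives each folded segment its own zero-based position run, with one extra slot when the segment is followed by a separator token. For example:

generate_position_ids(3)             # [0, 1, 2, 3]  -- three tokens plus the trailing sep
generate_position_ids(3, sep=False)  # [0, 1, 2]     -- no separator counted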
@@ -109,6 +117,7 @@ def fold_inputs(

     batch_size = len(text_list)
     input_ids = [ [] for _ in range(batch_size) ]
+    position_ids = [ [] for _ in range(batch_size) ]

     offset = 0

@@ -142,17 +151,23 @@ def fold_inputs(
             seq = text + text_start
         else:
             seq = torch.tensor([text_start + text], device=device, dtype=dtype)

         input_ids[i].append( seq )
         input_ids[i].append( sep )

+        position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
     # lang tokens
     for i, lang in enumerate(lang_list):
         if isinstance(lang, torch.Tensor):
             seq = lang + lang_start
         else:
             seq = torch.tensor([lang_start + lang], device=device, dtype=dtype)

         input_ids[i].append( seq )
         input_ids[i].append( sep )

+        position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
     # inject target quant_level
     if quant_levels is not None:
@@ -164,15 +179,20 @@ def fold_inputs(
         input_ids[i].append( seq )
         input_ids[i].append( sep )

+        position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
     # prom / task tokens
     for i, prom in enumerate(prom_list):
         # list of proms with a possible task token
+        length = 0
         if isinstance(prom, list):
             for p in prom:
-                process_prom_or_task(i, p)
+                length += process_prom_or_task(i, p)
         # raw tensor
         else:
-            process_prom_or_task(i, prom)
+            length += process_prom_or_task(i, prom)

+        position_ids[i].append( generate_position_ids( length, sep=False ) )
+
     # tone tokens
     for i, tone in enumerate(tone_list):
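Since process_prom_or_task now reports how many ids it appended (the sequence length plus its separator, or 0 for a None entry), a prompt built from several chunks gets one continuous position run. A rough worked example, with hypothetical chunk lengths:

# two prom chunks of 3 and 2 tokens; each append also adds a sep token
length = (3 + 1) + (2 + 1)                # 8 ids appended for this prompt
generate_position_ids(length, sep=False)  # [0, 1, ..., 7] -- one run across all chunks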
@@ -183,6 +203,8 @@ def fold_inputs(
         input_ids[i].append( seq )
         input_ids[i].append( sep )

+        position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
     # resp tokens
     for i, resp in enumerate(resp_list):
         # deinterleaved
@@ -205,6 +227,8 @@ def fold_inputs(

             input_ids[i].append( seq )
             input_ids[i].append( stop )

+            position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
         # interleaved
         else:
             seq = resp.flatten().to(device=device, dtype=dtype)
@@ -213,6 +237,8 @@ def fold_inputs(

             input_ids[i].append( seq )
             input_ids[i].append( stop )

+            position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
     # targ list
     for i, resp in enumerate(targ_list):
@@ -225,6 +251,8 @@ def fold_inputs(

             input_ids[i].append( seq )
             input_ids[i].append( stop )

+            position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
         # interleaved
         else:
             seq = resp.flatten().to(device=device, dtype=dtype)
@@ -233,11 +261,17 @@ def fold_inputs(

             input_ids[i].append( seq )
             input_ids[i].append( stop )

+            position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
     for i, batch in enumerate(input_ids):
         input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=dtype)
+        position_ids[i] = torch.concat([ torch.tensor(ids, device=device, dtype=dtype) for ids in position_ids[i] ], dim=-1)

-    return list_to_tensor(input_ids)
+    input_ids, attention_mask = list_to_tensor(input_ids)
+    position_ids = list_to_tensor(position_ids, mask=False)
+
+    return input_ids, attention_mask, position_ids

 # unfold from one unified token ID space to separate token spaces
 # to-do: unfold at a specific RVQ level instead if requested
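With this, fold_inputs hands back three aligned (b t) tensors rather than the old (ids, mask) pair; callers unpack it along these lines (the argument values are illustrative):

input_ids, attention_mask, position_ids = fold_inputs(
    text_list=text_list,
    prom_list=proms_list,
    resp_list=resps_list,
    quant_levels=quant_levels,
)
# input_ids      : folded token ids
# attention_mask : 1 for real tokens, 0 for padding
# position_ids   : per-segment positions, padded via list_to_tensor(..., mask=False)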
@@ -316,7 +316,7 @@ class AR_NAR(Base):


 def example_usage():
-    # cfg.trainer.backend = "local"
+    cfg.trainer.backend = "local"
     cfg.hyperparameters.gradient_accumulation_steps = 1
     if cfg.audio_backend == "dac":
         cfg.sample_rate = 44_100
@@ -398,7 +398,7 @@ def example_usage():
     tasks = cfg.dataset.tasks_list

     model = AR_NAR(**kwargs).to(device)
-    steps = 150 * len(tasks) * cfg.model.experimental.causal_size
+    steps = 150 * len(tasks) # * cfg.model.experimental.causal_size

     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -9,9 +9,9 @@ def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_
     residual = None
     for layer in self.layers:
         if self.gradient_checkpointing and hidden_states.requires_grad:
-            hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
+            hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, **mixer_kwargs, use_reentrant=False )
         else:
-            hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
+            hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params, **mixer_kwargs )
     if not self.fused_add_norm:
         residual = (hidden_states + residual) if residual is not None else hidden_states
         hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
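Forwarding **mixer_kwargs through checkpoint leans on the non-reentrant path of torch.utils.checkpoint passing extra keyword arguments on to the wrapped callable; a minimal sketch of that pattern, with a stand-in layer and a stand-in kwarg name:

import torch
from torch.utils.checkpoint import checkpoint

def layer(hidden_states, residual, inference_params=None, **mixer_kwargs):
    # stand-in for a Mamba block; extra kwargs arrive intact
    return hidden_states * 2, residual

h = torch.randn(2, 4, requires_grad=True)
mixer_kwargs = {"seq_idx": None}  # hypothetical, just to show the plumbing
h, r = checkpoint(layer, h, None, inference_params=None, **mixer_kwargs, use_reentrant=False)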
@@ -701,12 +701,12 @@ class Base(nn.Module):
             self.model = MambaMixelModel(
                 vocab_size=n_resp_tokens,
                 d_model=d_model,
-                n_layer=n_layers,
-                d_intermediate=d_model*4,
-                ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False} if self.arch_type == "mamba2" else {},
+                n_layer=n_layers*2,
+                d_intermediate=0, #d_model*2,
+                ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if self.arch_type == "mamba2" else {},
                 rms_norm=True,
                 fused_add_norm=True,
-                residual_in_fp32=False,
+                residual_in_fp32=True,
                 #attn_layer_idx=attn_layer_idx,
                 #attn_cfg=attn_cfg,
                 #initializer_cfg=initializer_cfg,
@@ -722,7 +722,7 @@ class Base(nn.Module):
                 is_encoder_decoder=False,
                 is_decoder=True,
                 use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
-                residual_in_fp32=False, # breaks for AMP inference
+                residual_in_fp32=True, # breaks for AMP inference
             ))
             if self.gradient_checkpointing and not self.model.gradient_checkpointing:
                 self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
@@ -59,6 +59,15 @@ class Model(LlmArchClass):
         # text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
         # vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1

+        if hf_attention == "auto":
+            if AVAILABLE_ATTENTIONS:
+                hf_attention = AVAILABLE_ATTENTIONS[0]
+            else:
+                hf_attention = "eager"
+
+        if hf_attention == "xformers":
+            hf_attention = "mem_efficient"
+
         text_start = 0
         text_end = text_start + config.text_tokens

@@ -82,17 +91,17 @@ class Model(LlmArchClass):

         vocab_size = resp_end

-        if cfg.model.arch_type == "llama":
+        if config.arch_type == "llama":
             super().__init__(config=LlamaConfig(
                 vocab_size=vocab_size,
                 hidden_size=d_model,
-                max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.max_levels * 60, # max-length of 60 seconds
+                max_position_embeddings=cfg.dataset.frames_per_second * config.max_levels * 60, # max-length of 60 seconds
                 intermediate_size=d_model*4,
                 num_hidden_layers=n_layers,
                 num_attention_heads=n_heads,
                 attention_dropout=p_dropout,
                 num_key_value_heads=n_heads,
-                sliding_window=cfg.dataset.frames_per_second * cfg.model.max_levels * 12,
+                sliding_window=cfg.dataset.frames_per_second * config.max_levels * 12,
                 hidden_act="gelu",
                 is_encoder_decoder=False,
                 is_decoder=True,
@@ -103,7 +112,7 @@ class Model(LlmArchClass):
             self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                 use_reentrant=False
             ))
-        elif cfg.model.arch_type == "retnet":
+        elif config.arch_type == "retnet":
             super().__init__(config=RetNetConfig(
                 vocab_size=vocab_size,
                 decoder_embed_dim=d_model,
@@ -125,16 +134,16 @@ class Model(LlmArchClass):

                 decoder_normalize_before=True,
             ))
-        elif cfg.model.arch_type in ["mamba","mamba2"]:
+        elif config.arch_type in ["mamba","mamba2"]:
             super().__init__(config=MambaConfig(
                 vocab_size=vocab_size,
                 d_model=d_model,
-                n_layer=n_layers,
-                d_intermediate=d_model*4,
-                ssm_cfg={"layer": "Mamba2"} if cfg.model.arch_type == "mamba2" else {},
+                n_layer=n_layers*2,
+                d_intermediate=0, # d_model*4,
+                ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if config.arch_type == "mamba2" else {},
                 rms_norm=True,
                 fused_add_norm=True,
-                residual_in_fp32=True,
+                residual_in_fp32=False,
             ))

             self.backbone.gradient_checkpointing = gradient_checkpointing
@@ -163,8 +172,87 @@ class Model(LlmArchClass):

         if "min_length" in kwargs:
             kwargs.pop("min_length")

+        """
+        if "position_ids" in kwargs:
+            kwargs.pop("position_ids")
+
+        if "max_new_tokens" in kwargs:
+            kwargs.pop("max_new_tokens")
+
+        if "max_length" not in kwargs:
+            kwargs["max_length"] = 500 * (self.hyper_config.resp_levels if self.hyper_config.experimental.interleave else 1)
+
+        if "num_last_tokens" not in kwargs:
+            kwargs["num_last_tokens"] = self.hyper_config.experimental.causal_size
+        """
+
+        input_ids = kwargs.pop("input_ids")
+        attention_mask = kwargs.pop("attention_mask", None)
+        position_ids = kwargs.pop("position_ids", None)
+
+        stop_token = kwargs.pop("eos_token_id", 3)
+        max_steps = kwargs.pop("max_new_tokens", 500)
+
+        device = input_ids.device
+        batch_size = input_ids.shape[0]
+
+        sequence_list = [ inputs for inputs in input_ids ]
+        position_list = [ positions for positions in position_ids ]
+
+        start_positions = [ inputs.shape[0] for inputs in input_ids ]
+
+        stopped = torch.zeros(batch_size, device=device).bool()
+
+        config = self.hyper_config
+        state = None
+        disable_tqdm = False
+        causal_size = config.experimental.causal_size
+
+        # get next in sequence
+        for n in trange(max_steps // max(1, causal_size), desc="AR", disable=disable_tqdm):
+            output = super().forward(
+                input_ids=torch.stack(sequence_list),
+                #attention_mask=attention_mask,
+                #past_key_values=state,
+                #position_ids=torch.stack(position_list),
+                #use_cache=False,
+                #return_dict=False
+            )
+
+            logits = output[0]
+            # state = output[1]
+
+            r = [ logit[-causal_size:].argmax(dim=1) for logit in logits ]
+
+            # append tokens
+            for i, ri in enumerate(r):
+                if stop_token in ri:
+                    stopped[i] = True
+
+                last_position_id = position_list[i][-1].item() + 1
+                sequence_list[i] = torch.cat([ sequence_list[i], ri.to(device) ], dim=0)
+                #position_list[i] = torch.cat([ position_list[i], torch.tensor([ last_position_id + _ for _ in range( ri.shape[0] ) ], device=device, dtype=torch.int32) ])

+            # stop token found
+            stopped |= r == stop_token
+            if stopped.all().item():
+                break
+
+        def _prune(l: Tensor, stop = stop_token):
+            indices = (l == stop).nonzero()
+
+            if len(indices) == 0:
+                return l
+
+            return l[: indices.min().item()]
+
+        sequence_list = [ _prune(seq[start_positions[i]:], stop_token) for i, seq in enumerate(sequence_list) ]
+        return torch.stack(sequence_list)
+
+        """
         return super().generate(*args, **kwargs)
+        """

     def forward(
         self,
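The replacement generate() decodes greedily, taking the last causal_size logits each step, then slices every sequence back to the newly generated ids and cuts at the first stop token. _prune on its own behaves like this:

import torch

stop_token = 3

def _prune(l: torch.Tensor, stop = stop_token):
    indices = (l == stop).nonzero()
    if len(indices) == 0:
        return l
    return l[: indices.min().item()]

_prune(torch.tensor([7, 9, 4, 3, 8, 3]))  # tensor([7, 9, 4]) -- everything before the first stop
_prune(torch.tensor([7, 9]))              # unchanged, no stop token present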
@@ -188,14 +276,14 @@ class Model(LlmArchClass):
         if training:
             quant_levels = None if config.experimental.interleave else [ random.randint( 0 if "ar" in config.capabilities else 1, config.max_levels - 1) for _ in range(batch_size) ]

-            input_ids, attention_mask = fold_inputs(
+            input_ids, attention_mask, position_ids = fold_inputs(
                 text_list=text_list,
                 prom_list=proms_list,
                 resp_list=resps_list,
                 targ_list=resps_list,
                 quant_levels=quant_levels,
             )
-            target_ids, target_attention_mask = fold_inputs(
+            target_ids, target_attention_mask, target_position_ids = fold_inputs(
                 text_list=text_list,
                 prom_list=proms_list,
                 resp_list=resps_list,
@@ -206,14 +294,16 @@ class Model(LlmArchClass):
             return self.forward(
                 input_ids=input_ids,
                 labels=target_ids,
+                position_ids=position_ids,

                 quant_levels=quant_levels,
             )

         if config.experimental.interleave:
-            input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
+            input_ids, attention_mask, position_ids = fold_inputs( text_list=text_list, prom_list=proms_list )
             output = self.generate(
                 input_ids=input_ids,
+                position_ids=position_ids,
                 attention_mask=attention_mask,
                 eos_token_id=3,
                 do_sample=True,
@@ -225,7 +315,7 @@ class Model(LlmArchClass):
         for l in range(config.max_levels):
             quant_levels = [ l for _ in range(batch_size) ]

-            input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
+            input_ids, attention_mask, position_ids = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
             min_length = 1
             for batch in input_ids:
                 min_length = max( min_length, batch.shape[0] + 1 )
@@ -234,6 +324,7 @@ class Model(LlmArchClass):
             output = self.generate(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
+                position_ids=position_ids,
                 eos_token_id=3,
                 do_sample=True,
                 max_new_tokens=steps,
@@ -273,10 +364,13 @@ class Model(LlmArchClass):

         # i HATE the correct way
         if labels is not None:
+            if quant_levels is None:
+                quant_levels = [0 for _ in range(labels.shape[0])]
+
             # predict the next token for AR, else predict in place
             loss = sum([ F.cross_entropy(
-                logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
-                label[1:] if quant_level == 0 or "nar" not in config.capabilities else label,
+                logit[:-config.experimental.causal_size, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
+                label[config.experimental.causal_size:] if quant_level == 0 or "nar" not in config.capabilities else label,
                 ignore_index=-100
             ) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])
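The causal loss now shifts by causal_size instead of a single position, i.e. the usual next-token objective generalized to predicting k tokens ahead. A toy version of the slicing for one sequence:

import torch
import torch.nn.functional as F

k = 2                                  # stand-in for config.experimental.causal_size
logits = torch.randn(10, 32)           # (t, vocab)
labels = torch.randint(0, 32, (10,))   # (t,)

# logits at position i are scored against the label k steps later
loss = F.cross_entropy(logits[:-k, :], labels[k:], ignore_index=-100)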
@@ -372,7 +466,7 @@ def example_usage():

     kwargs = {}
     model = Model(**kwargs).to(device)
-    steps = 100 if cfg.model.experimental.interleave else 300
+    steps = 50 # 100 if cfg.model.experimental.interleave else 300

     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""