tweaks

parent 6d5bd0156a
commit 014e565c4b
@@ -76,6 +76,14 @@ def fold_inputs(
 			input_ids[i].append( sep )
 
 	offset = text_tokens
+	# inject target quant_level
+	if quant_levels is not None:
+		for i, rvq in enumerate( quant_levels ):
+			seq = torch.Tensor([offset + rvq]).to("cpu", dtype=torch.int64)
+			input_ids[i].append( seq )
+			input_ids[i].append( sep )
+
+	offset = text_tokens + audio_rvq_levels
 	for i, prom in enumerate(prom_list):
 		# deinterleaved
 		if quant_levels is not None:
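The injected marker is just `offset + rvq`, so the RVQ-level tokens sit in the id range immediately after the text vocabulary. A minimal sketch of that mapping, assuming text_tokens=256 and audio_rvq_levels=8 (illustrative values, not read from the diff):

    text_tokens = 256
    audio_rvq_levels = 8

    def rvq_level_token(rvq: int) -> int:
        # level markers occupy [text_tokens, text_tokens + audio_rvq_levels)
        assert 0 <= rvq < audio_rvq_levels
        return text_tokens + rvq

    print(rvq_level_token(0))  # 256
    print(rvq_level_token(7))  # 263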
@@ -98,7 +106,7 @@ def fold_inputs(
 				input_ids[i].append( seq )
 				input_ids[i].append( sep )
 
-	offset = text_tokens + (audio_tokens * audio_rvq_levels)
+	offset = text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels)
 
 	for i, resp in enumerate(resp_list):
 		# deinterleaved
@@ -107,7 +115,10 @@ def fold_inputs(
 				quant_level = quant_levels[i] - 1
+				# way to signal we want to inference for rvq level 0
+				# without it, it's a random chance for any level to be selected again
 
 				if quant_level < 0:
-					continue
-
+					seq = sep
 				else:
 					# my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples
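The two added comments carry the rationale: quant_levels[i] names the level being predicted, and the resp that gets folded in is the level below it, so a request for level 0 has no previous level to feed in and needs an explicit stand-in (the separator). A hypothetical illustration of that indexing, not the repo's code:

    def prev_level(target_level: int):
        # target_level is the RVQ level being predicted; the folded resp is one level below it
        quant_level = target_level - 1
        if quant_level < 0:
            # predicting level 0: nothing to condition on, only a separator gets emitted
            return None
        return quant_level

    assert prev_level(0) is None
    assert prev_level(3) == 2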
@@ -192,10 +203,10 @@ def unfold_outputs(
 
 			if 0 <= id and id < text_tokens:
 				text_list[i].append( id )
-			elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels):
-				prom_list[i].append( (id - text_tokens) % audio_tokens )
-			elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
-				resp_list[i].append( (id - text_tokens) % audio_tokens )
+			elif text_tokens + audio_rvq_levels <= id and id < text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels):
+				prom_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
+			elif text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels) <= id:
+				resp_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
 			if not flushed:
 				should_flush = True
 
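The reworked elif chain in unfold_outputs partitions the flat id space back into text / rvq-level / prom / resp ranges, now shifted by audio_rvq_levels. A sketch of that decoding with illustrative sizes (text_tokens=256, audio_rvq_levels=8, audio_tokens=1024); the stop token is left out here:

    text_tokens, audio_rvq_levels, audio_tokens = 256, 8, 1024
    prom_start = text_tokens + audio_rvq_levels                # 264
    resp_start = prom_start + audio_tokens * audio_rvq_levels  # 8456

    def classify(token_id: int):
        if token_id < text_tokens:
            return "text", token_id
        if token_id < prom_start:
            return "rvq_level", token_id - text_tokens
        if token_id < resp_start:
            return "prom", (token_id - text_tokens - audio_rvq_levels) % audio_tokens
        return "resp", (token_id - text_tokens - audio_rvq_levels) % audio_tokens

    print(classify(258))   # ('rvq_level', 2)
    print(classify(300))   # ('prom', 36)
    print(classify(9000))  # ('resp', 544)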
@@ -71,6 +71,7 @@ try:
 	MambaMixelModel.forward = MambaMixelModel_forward
 
 	AVAILABLE_ARCHES.append("mamba")
+	AVAILABLE_ARCHES.append("mamba2")
 except Exception as e:
 	print("Error importing `mamba` arch:", e)
 	pass
@@ -80,7 +81,7 @@ SELECTED_ARCH = cfg.model.arch_type
 if SELECTED_ARCH not in AVAILABLE_ARCHES:
 	raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
 
-if SELECTED_ARCH == "mamba":
+if SELECTED_ARCH in ["mamba","mamba2"]:
 	LlmArchClass = MambaLMHeadModel
 elif SELECTED_ARCH == "llama":
 	LlmArchClass = LlamaForCausalLM
@@ -103,7 +104,8 @@ class Model(LlmArchClass):
 
 		hf_attention = config.attention if config is not None else None
 		gradient_checkpointing = config.gradient_checkpointing if config is not None else True
-		vocab_size = 256 + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
+		# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
+		vocab_size = 256 + cfg.model.max_levels + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
 
 		if SELECTED_ARCH == "llama":
 			super().__init__(config=LlamaConfig(
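The extra cfg.model.max_levels term in vocab_size accounts for those injected level markers. Worked arithmetic, assuming max_levels=8:

    max_levels = 8
    text   = 256
    levels = max_levels          # 8, new in this commit
    prom   = 1024 * max_levels   # 8192
    resp   = 1024 * max_levels   # 8192
    stop   = 1
    print(text + levels + prom + resp + stop)  # 16649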
@@ -148,12 +150,12 @@ class Model(LlmArchClass):
 
 				decoder_normalize_before=True,
 			))
-		elif SELECTED_ARCH == "mamba":
+		elif SELECTED_ARCH in ["mamba","mamba2"]:
 			super().__init__(config=MambaConfig(
 				vocab_size=vocab_size,
 				d_model=d_model,
 				n_layer=n_layers*2,
-				#ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
+				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if SELECTED_ARCH == "mamba2" else {},
 			))
 
 			self.backbone.gradient_checkpointing = gradient_checkpointing
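ssm_cfg now actually selects the Mamba2 mixer (with a chunk_size) instead of staying commented out. A sketch of the resulting config, where d_model/n_layers are illustrative and the import path is assumed from mamba_ssm:

    from mamba_ssm.models.config_mamba import MambaConfig  # assumed import path

    d_model, n_layers = 1024, 12  # illustrative sizes
    SELECTED_ARCH = "mamba2"

    config = MambaConfig(
        vocab_size=16649,
        d_model=d_model,
        n_layer=n_layers * 2,  # doubled, mirroring the diff
        ssm_cfg={"layer": "Mamba2", "chunk_size": 64} if SELECTED_ARCH == "mamba2" else {},
    )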
@@ -163,7 +165,7 @@ class Model(LlmArchClass):
 		*args,
 		**kwargs
 	):
-		if SELECTED_ARCH == "mamba":
+		if SELECTED_ARCH in ["mamba","mamba2"]:
 			kwargs["cg"] = True
 
 		if "attention_mask" in kwargs:
@@ -182,7 +184,7 @@ class Model(LlmArchClass):
 		*args,
 		**kwargs,
 	):
-		if SELECTED_ARCH == "mamba":
+		if SELECTED_ARCH in ["mamba","mamba2"]:
 			if "attention_mask" in kwargs:
 				kwargs.pop("attention_mask")
 
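Both wrappers now treat mamba and mamba2 identically: cg=True presumably opts into mamba_ssm's cached-generation path, and attention_mask is stripped because the Mamba backbone's forward doesn't take it. A hypothetical helper showing the same filtering, not code from the repo:

    def prep_mamba_kwargs(kwargs: dict, generating: bool = False) -> dict:
        # drop HF-style arguments the Mamba backbone doesn't understand
        kwargs = dict(kwargs)
        kwargs.pop("attention_mask", None)
        if generating:
            kwargs["cg"] = True  # assumption: mamba_ssm's cached-generation flag
        return kwargs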
@@ -193,7 +195,7 @@ class Model(LlmArchClass):
 			self.loss = dict(
 				nll = output.loss,
 			)
-		elif SELECTED_ARCH == "mamba":
+		elif SELECTED_ARCH in ["mamba","mamba2"]:
 			if "labels" in kwargs:
 				labels = kwargs.pop("labels")
 				logits = output.logits
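On the mamba path there is no HF-style output.loss, so labels get popped and the loss has to be computed from the logits directly. A hedged sketch of the usual shifted cross-entropy, not necessarily the repo's exact formulation:

    import torch
    import torch.nn.functional as F

    def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
        # logits: [batch, seq, vocab], labels: [batch, seq]; predict token t+1 from positions <= t
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=ignore_index,
        )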
@@ -262,38 +264,15 @@ def example_usage():
 	prom_list = prom_list[:1]
 	resp_list = resp_list[:1]
 
-	if False:
-		output_list = [ [] ]
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[0])
-		unfolded = unfold_outputs( input_ids, quant_levels=[0])
-		print( 0, "inputs:", input_ids.shape, input_ids )
-		print( 0, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 0] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[1])
-		unfolded = unfold_outputs( input_ids, quant_levels=[1])
-		print( 1, "inputs:", input_ids.shape, input_ids )
-		print( 1, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 1] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[2])
-		unfolded = unfold_outputs( input_ids, quant_levels=[2])
-		print( 2, "inputs:", input_ids.shape, input_ids )
-		print( 2, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 2] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[3])
-		unfolded = unfold_outputs( input_ids, quant_levels=[3])
-		print( 3, "inputs:", input_ids.shape, input_ids )
-		print( 3, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 3] )
-
-		return
-
 	kwargs = {}
 	model = Model(**kwargs).to(device)
-	steps = 50 if cfg.model.interleave else 250
+	steps = 100
+	if cfg.model.arch_type == "mamba2":
+		steps = 100
+	elif cfg.model.arch_type == "llama":
+		steps = 500
+	elif cfg.model.interleave:
+		steps = 250
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""