mrq 2024-06-04 20:41:13 -05:00
parent 6d5bd0156a
commit 014e565c4b
2 changed files with 32 additions and 42 deletions

View File

@@ -76,6 +76,14 @@ def fold_inputs(
 		input_ids[i].append( sep )

 	offset = text_tokens
+	# inject target quant_level
+	if quant_levels is not None:
+		for i, rvq in enumerate( quant_levels ):
+			seq = torch.Tensor([offset + rvq]).to("cpu", dtype=torch.int64)
+			input_ids[i].append( seq )
+			input_ids[i].append( sep )
+
+	offset = text_tokens + audio_rvq_levels
 	for i, prom in enumerate(prom_list):
 		# deinterleaved
 		if quant_levels is not None:
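Note: the net effect of this hunk is a new block of level-marker ids wedged between the text ids and the prompt ids. A minimal sketch of the id layout this implies, assuming the sizes suggested by the vocab formula later in the commit (text_tokens=256, audio_tokens=1024, audio_rvq_levels=cfg.model.max_levels=8):

text_tokens = 256        # assumed text vocab size
audio_tokens = 1024      # assumed codes per codebook
audio_rvq_levels = 8     # assumed cfg.model.max_levels

# [text][rvq-level markers][prom codes * levels][resp codes * levels][stop]
rvq_offset  = text_tokens                                    # 256
prom_offset = rvq_offset + audio_rvq_levels                  # 264
resp_offset = prom_offset + audio_tokens * audio_rvq_levels  # 8456
stop_token  = resp_offset + audio_tokens * audio_rvq_levels  # 16648

def rvq_marker(level: int) -> int:
    # the id the new block injects to flag the target RVQ level
    return rvq_offset + level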
@@ -98,7 +106,7 @@ def fold_inputs(
 			input_ids[i].append( seq )
 			input_ids[i].append( sep )

-	offset = text_tokens + (audio_tokens * audio_rvq_levels)
+	offset = text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels)
 	for i, resp in enumerate(resp_list):
 		# deinterleaved
@@ -107,7 +115,10 @@ def fold_inputs(
 			quant_level = quant_levels[i] - 1
+			# way to signal we want to inference for rvq level 0
+			# without it, it's a random chance for any level to be selected again
 			if quant_level < 0:
-				continue
+				seq = sep
+			else:
 				# my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples
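Note: the rewritten branch stops skipping an underflowed level request and instead emits a bare separator, which the comments describe as the cue to inference at RVQ level 0. A sketch of the intent; the else body is elided in the diff, so the indexing shown is an assumption:

quant_level = quant_levels[i] - 1
if quant_level < 0:
    seq = sep                   # bare separator == "inference RVQ level 0"
else:
    seq = resp[:, quant_level]  # assumed: select that level's codes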
@@ -192,10 +203,10 @@ def unfold_outputs(
 			if 0 <= id and id < text_tokens:
 				text_list[i].append( id )
-			elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels):
-				prom_list[i].append( (id - text_tokens) % audio_tokens )
-			elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
-				resp_list[i].append( (id - text_tokens) % audio_tokens )
+			elif text_tokens + audio_rvq_levels <= id and id < text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels):
+				prom_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
+			elif text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels) <= id:
+				resp_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
 			if not flushed:
 				should_flush = True
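Note: both elif ranges shift up by audio_rvq_levels to step over the new marker block, and the modulo is unchanged because the prom block's width is a multiple of audio_tokens. A hedged sketch of the updated classification, reusing the offsets assumed above (the rvq-level and stop cases are added here for completeness, not shown in the diff):

def classify(id: int):
    # mirrors the rewritten unfold_outputs ranges
    if 0 <= id < text_tokens:
        return ("text", id)
    if rvq_offset <= id < prom_offset:
        return ("rvq-level", id - rvq_offset)
    if prom_offset <= id < resp_offset:
        return ("prom", (id - text_tokens - audio_rvq_levels) % audio_tokens)
    if resp_offset <= id < stop_token:
        return ("resp", (id - text_tokens - audio_rvq_levels) % audio_tokens)
    return ("stop", None)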

View File

@@ -71,6 +71,7 @@ try:
 	MambaMixelModel.forward = MambaMixelModel_forward
 	AVAILABLE_ARCHES.append("mamba")
+	AVAILABLE_ARCHES.append("mamba2")
 except Exception as e:
 	print("Error importing `mamba` arch:", e)
 	pass
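Note: mamba2 piggybacks on the same guarded import, so the arch is only advertised when the dependency actually loads. The pattern, reduced to a standalone sketch (the mamba_ssm import path is upstream's; everything else is illustrative):

AVAILABLE_ARCHES = ["llama"]
try:
    # upstream entry point for mamba-ssm's LM wrapper
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
    AVAILABLE_ARCHES += ["mamba", "mamba2"]  # one import gates both
except Exception as e:
    print("Error importing `mamba` arch:", e)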
@@ -80,7 +81,7 @@ SELECTED_ARCH = cfg.model.arch_type
 if SELECTED_ARCH not in AVAILABLE_ARCHES:
 	raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")

-if SELECTED_ARCH == "mamba":
+if SELECTED_ARCH in ["mamba","mamba2"]:
 	LlmArchClass = MambaLMHeadModel
 elif SELECTED_ARCH == "llama":
 	LlmArchClass = LlamaForCausalLM
@@ -103,7 +104,8 @@ class Model(LlmArchClass):
 		hf_attention = config.attention if config is not None else None
 		gradient_checkpointing = config.gradient_checkpointing if config is not None else True
-		vocab_size = 256 + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
+		# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
+		vocab_size = 256 + cfg.model.max_levels + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1

 		if SELECTED_ARCH == "llama":
 			super().__init__(config=LlamaConfig(
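Note: worked through with the max_levels=8 assumed earlier, the only change to the vocab is the extra block of level markers:

max_levels = 8  # assumed cfg.model.max_levels
old = 256 + (1024 * max_levels) + (1024 * max_levels) + 1               # 16641
new = 256 + max_levels + (1024 * max_levels) + (1024 * max_levels) + 1  # 16649
assert new - old == max_levels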
@@ -148,12 +150,12 @@ class Model(LlmArchClass):
 				decoder_normalize_before=True,
 			))
-		elif SELECTED_ARCH == "mamba":
+		elif SELECTED_ARCH in ["mamba","mamba2"]:
 			super().__init__(config=MambaConfig(
 				vocab_size=vocab_size,
 				d_model=d_model,
 				n_layer=n_layers*2,
-				#ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
+				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if SELECTED_ARCH == "mamba2" else {},
 			))

 			self.backbone.gradient_checkpointing = gradient_checkpointing
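Note: mamba_ssm's MambaConfig forwards ssm_cfg to every block, so "layer": "Mamba2" swaps the mixer class per layer, and the chunk_size=64 appears to be the workaround for the NaNs the removed comment complained about. A standalone sketch with illustrative sizes:

from mamba_ssm.models.config_mamba import MambaConfig

config = MambaConfig(
    vocab_size=16649,  # from the arithmetic above
    d_model=1024,      # illustrative
    n_layer=24,        # illustrative (the diff doubles n_layers)
    ssm_cfg={"layer": "Mamba2", "chunk_size": 64},
)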
@@ -163,7 +165,7 @@ class Model(LlmArchClass):
 		*args,
 		**kwargs
 	):
-		if SELECTED_ARCH == "mamba":
+		if SELECTED_ARCH in ["mamba","mamba2"]:
 			kwargs["cg"] = True

 			if "attention_mask" in kwargs:
@@ -182,7 +184,7 @@ class Model(LlmArchClass):
 		*args,
 		**kwargs,
 	):
-		if SELECTED_ARCH == "mamba":
+		if SELECTED_ARCH in ["mamba","mamba2"]:
 			if "attention_mask" in kwargs:
 				kwargs.pop("attention_mask")
@@ -193,7 +195,7 @@ class Model(LlmArchClass):
 			self.loss = dict(
 				nll = output.loss,
 			)
-		elif SELECTED_ARCH == "mamba":
+		elif SELECTED_ARCH in ["mamba","mamba2"]:
 			if "labels" in kwargs:
 				labels = kwargs.pop("labels")
 				logits = output.logits
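Note: MambaLMHeadModel returns logits without computing a loss, so this branch pops labels and, per the lines that follow (elided in the diff), presumably does the causal-LM shift by hand. A sketch of that standard step:

import torch
import torch.nn.functional as F

def shifted_nll(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # predict token t+1 from positions <= t, as HF's CausalLM losses do
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # assumed ignore id for padding
    )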
@@ -262,38 +264,15 @@ def example_usage():
 	prom_list = prom_list[:1]
 	resp_list = resp_list[:1]

-	if False:
-		output_list = [ [] ]
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[0])
-		unfolded = unfold_outputs( input_ids, quant_levels=[0])
-		print( 0, "inputs:", input_ids.shape, input_ids )
-		print( 0, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 0] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[1])
-		unfolded = unfold_outputs( input_ids, quant_levels=[1])
-		print( 1, "inputs:", input_ids.shape, input_ids )
-		print( 1, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 1] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[2])
-		unfolded = unfold_outputs( input_ids, quant_levels=[2])
-		print( 2, "inputs:", input_ids.shape, input_ids )
-		print( 2, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 2] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[3])
-		unfolded = unfold_outputs( input_ids, quant_levels=[3])
-		print( 3, "inputs:", input_ids.shape, input_ids )
-		print( 3, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 3] )
-		return

 	kwargs = {}
 	model = Model(**kwargs).to(device)

-	steps = 50 if cfg.model.interleave else 250
+	steps = 100
+	if cfg.model.arch_type == "mamba2":
+		steps = 100
+	elif cfg.model.arch_type == "llama":
+		steps = 500
+	elif cfg.model.interleave:
+		steps = 250

 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
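Note: when no config path is given, the optimizer string falls back to "prodigy". Assuming that resolves to the prodigyopt package, the usual near-hyperparameter-free setup would look like:

from prodigyopt import Prodigy  # pip install prodigyopt

optimizer = Prodigy(model.parameters(), lr=1.0)  # Prodigy expects lr=1.0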