tweaks
parent 6d5bd0156a
commit 014e565c4b

@@ -76,6 +76,14 @@ def fold_inputs(
 		input_ids[i].append( sep )

 	offset = text_tokens
+	# inject target quant_level
+	if quant_levels is not None:
+		for i, rvq in enumerate( quant_levels ):
+			seq = torch.Tensor([offset + rvq]).to("cpu", dtype=torch.int64)
+			input_ids[i].append( seq )
+			input_ids[i].append( sep )
+
+	offset = text_tokens + audio_rvq_levels
 	for i, prom in enumerate(prom_list):
 		# deinterleaved
 		if quant_levels is not None:
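
A minimal sketch of the injected level marker, assuming text_tokens = 256 (the constant used in the vocab_size expression later in this commit); the names and values here are illustrative, not taken from the hunk:

    import torch

    text_tokens = 256
    for rvq in (0, 3, 7):
        # the target RVQ level is folded in as a single token id: text_tokens + rvq
        seq = torch.Tensor([text_tokens + rvq]).to("cpu", dtype=torch.int64)
        print(rvq, seq)  # tensor([256]), tensor([259]), tensor([263])
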
@@ -98,7 +106,7 @@ def fold_inputs(
 			input_ids[i].append( seq )
 			input_ids[i].append( sep )

-	offset = text_tokens + (audio_tokens * audio_rvq_levels)
+	offset = text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels)

 	for i, resp in enumerate(resp_list):
 		# deinterleaved
@@ -107,7 +115,10 @@ def fold_inputs(
 			quant_level = quant_levels[i] - 1
 			# way to signal we want to inference for rvq level 0
 			# without it, it's a random chance for any level to be selected again
+
 			if quant_level < 0:
+				continue
+
 				seq = sep
 			else:
 				# my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples
@@ -192,10 +203,10 @@ def unfold_outputs(

 			if 0 <= id and id < text_tokens:
 				text_list[i].append( id )
-			elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels):
-				prom_list[i].append( (id - text_tokens) % audio_tokens )
-			elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
-				resp_list[i].append( (id - text_tokens) % audio_tokens )
+			elif text_tokens + audio_rvq_levels <= id and id < text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels):
+				prom_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
+			elif text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels) <= id:
+				resp_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
 				if not flushed:
 					should_flush = True

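
A self-contained sketch of the token-id bands the updated offsets imply, assuming text_tokens = 256, audio_tokens = 1024 and audio_rvq_levels = 8 (taken from the vocab_size expression further down, not from this hunk); the classify helper is hypothetical and only mirrors the range checks above, plus a band for the level markers that unfold_outputs itself simply skips:

    # hypothetical helper; all constants are assumptions
    text_tokens = 256
    audio_tokens = 1024
    audio_rvq_levels = 8

    def classify(id):
        if 0 <= id < text_tokens:
            return "text", id
        if id < text_tokens + audio_rvq_levels:
            return "rvq-level", id - text_tokens
        if id < text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels):
            return "prom", (id - text_tokens - audio_rvq_levels) % audio_tokens
        return "resp", (id - text_tokens - audio_rvq_levels) % audio_tokens

    print(classify(42))    # ('text', 42)
    print(classify(259))   # ('rvq-level', 3)
    print(classify(264))   # ('prom', 0), first prom token of level 0
    print(classify(8456))  # ('resp', 0), first resp token of level 0
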
@@ -71,6 +71,7 @@ try:
 	MambaMixelModel.forward = MambaMixelModel_forward

 	AVAILABLE_ARCHES.append("mamba")
+	AVAILABLE_ARCHES.append("mamba2")
 except Exception as e:
 	print("Error importing `mamba` arch:", e)
 	pass
@@ -80,7 +81,7 @@ SELECTED_ARCH = cfg.model.arch_type
 if SELECTED_ARCH not in AVAILABLE_ARCHES:
 	raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")

-if SELECTED_ARCH == "mamba":
+if SELECTED_ARCH in ["mamba","mamba2"]:
 	LlmArchClass = MambaLMHeadModel
 elif SELECTED_ARCH == "llama":
 	LlmArchClass = LlamaForCausalLM
@@ -103,7 +104,8 @@ class Model(LlmArchClass):

 		hf_attention = config.attention if config is not None else None
 		gradient_checkpointing = config.gradient_checkpointing if config is not None else True
-		vocab_size = 256 + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
+		# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
+		vocab_size = 256 + cfg.model.max_levels + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1

 		if SELECTED_ARCH == "llama":
 			super().__init__(config=LlamaConfig(
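
Worked out with an assumed cfg.model.max_levels = 8 (an assumption, not a value read from this diff), the new expression counts 256 text tokens, 8 level markers, 8192 prom tokens, 8192 resp tokens and 1 stop token:

    max_levels = 8  # assumed for illustration
    vocab_size = 256 + max_levels + (1024 * max_levels) + (1024 * max_levels) + 1
    print(vocab_size)  # 16649
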
@@ -148,12 +150,12 @@ class Model(LlmArchClass):

 				decoder_normalize_before=True,
 			))
-		elif SELECTED_ARCH == "mamba":
+		elif SELECTED_ARCH in ["mamba","mamba2"]:
 			super().__init__(config=MambaConfig(
 				vocab_size=vocab_size,
 				d_model=d_model,
 				n_layer=n_layers*2,
-				#ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
+				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if SELECTED_ARCH == "mamba2" else {},
 			))

 			self.backbone.gradient_checkpointing = gradient_checkpointing
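
A standalone sketch of the Mamba2 branch above, assuming mamba_ssm >= 2.x (where MambaConfig accepts an ssm_cfg dict with a "layer" key); the sizes are placeholders, not values from this diff:

    from mamba_ssm.models.config_mamba import MambaConfig
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    config = MambaConfig(
        vocab_size=16649,  # placeholder, see the vocab_size expression above
        d_model=1024,      # placeholder
        n_layer=24,        # placeholder (the diff uses n_layers*2)
        ssm_cfg={"layer": "Mamba2", "chunk_size": 64},
    )
    model = MambaLMHeadModel(config)
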
@@ -163,7 +165,7 @@ class Model(LlmArchClass):
 		*args,
 		**kwargs
 	):
-		if SELECTED_ARCH == "mamba":
+		if SELECTED_ARCH in ["mamba","mamba2"]:
 			kwargs["cg"] = True

 		if "attention_mask" in kwargs:
@@ -182,7 +184,7 @@ class Model(LlmArchClass):
 		*args,
 		**kwargs,
 	):
-		if SELECTED_ARCH == "mamba":
+		if SELECTED_ARCH in ["mamba","mamba2"]:
 			if "attention_mask" in kwargs:
 				kwargs.pop("attention_mask")

@@ -193,7 +195,7 @@ class Model(LlmArchClass):
 			self.loss = dict(
 				nll = output.loss,
 			)
-		elif SELECTED_ARCH == "mamba":
+		elif SELECTED_ARCH in ["mamba","mamba2"]:
 			if "labels" in kwargs:
 				labels = kwargs.pop("labels")
 				logits = output.logits
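
The rest of this branch falls outside the hunk; a conventional shift-by-one cross entropy over the Mamba logits would look roughly like the sketch below (an assumption about what follows, not the commit's code):

    import torch.nn.functional as F

    # labels/logits as popped and read above; position t predicts token t+1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # assumed ignore id for padding
    )
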
@@ -262,38 +264,15 @@ def example_usage():
 	prom_list = prom_list[:1]
 	resp_list = resp_list[:1]

-	if False:
-		output_list = [ [] ]
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[0])
-		unfolded = unfold_outputs( input_ids, quant_levels=[0])
-		print( 0, "inputs:", input_ids.shape, input_ids )
-		print( 0, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 0] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[1])
-		unfolded = unfold_outputs( input_ids, quant_levels=[1])
-		print( 1, "inputs:", input_ids.shape, input_ids )
-		print( 1, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 1] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[2])
-		unfolded = unfold_outputs( input_ids, quant_levels=[2])
-		print( 2, "inputs:", input_ids.shape, input_ids )
-		print( 2, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 2] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[3])
-		unfolded = unfold_outputs( input_ids, quant_levels=[3])
-		print( 3, "inputs:", input_ids.shape, input_ids )
-		print( 3, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 3] )
-
-		return
-
 	kwargs = {}
 	model = Model(**kwargs).to(device)
-	steps = 50 if cfg.model.interleave else 250
+	steps = 100
+	if cfg.model.arch_type == "mamba2":
+		steps = 100
+	elif cfg.model.arch_type == "llama":
+		steps = 500
+	elif cfg.model.interleave:
+		steps = 250

 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""