ughghghhhh

mrq 2024-08-09 21:15:01 -05:00
parent ed373957e2
commit 2a1794c084
5 changed files with 159 additions and 31 deletions

View File

@@ -58,9 +58,11 @@ def fold_inputs(
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
return (seq < stop).float() # (b t)
def list_to_tensor(x_list: list[Tensor]):
def list_to_tensor(x_list: list[Tensor], mask=True):
l = list(map(len, x_list))
x = pad_sequence(x_list).t()
if not mask:
return x
m = _create_mask(l, x_list[0].device)
m = m.to(x)
@@ -68,7 +70,7 @@ def fold_inputs(
def process_prom_or_task(i, prom):
if prom is None:
return
return 0
if isinstance(prom, str):
task = get_task_symmap()[f'<{input}>']
@@ -76,7 +78,8 @@ def fold_inputs(
input_ids[i].append( seq )
input_ids[i].append( sep )
return
return seq.shape[0] + 1
# deinterleaved
if quant_levels is not None:
@@ -99,6 +102,11 @@ def fold_inputs(
input_ids[i].append( seq )
input_ids[i].append( sep )
return seq.shape[0] + 1
def generate_position_ids( length, sep=True ):
return [ i for i in range( length + (1 if sep else 0) ) ]
"""
if quant_levels is not None:
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
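The new generate_position_ids helper simply enumerates offsets within a single segment, reserving one extra slot for the trailing separator when sep=True, while process_prom_or_task now returns how many tokens (separator included) it appended, or 0 for a skipped prompt. For illustration:

    generate_position_ids(3)             # [0, 1, 2, 3]  -- 3 tokens plus the separator
    generate_position_ids(3, sep=False)  # [0, 1, 2]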
@@ -109,6 +117,7 @@ def fold_inputs(
batch_size = len(text_list)
input_ids = [ [] for _ in range(batch_size) ]
position_ids = [ [] for _ in range(batch_size) ]
offset = 0
@@ -142,17 +151,23 @@ def fold_inputs(
seq = text + text_start
else:
seq = torch.tensor([text_start + text], device=device, dtype=dtype)
input_ids[i].append( seq )
input_ids[i].append( sep )
position_ids[i].append( generate_position_ids( seq.shape[0] ) )
# lang tokens
for i, lang in enumerate(lang_list):
if isinstance(lang, torch.Tensor):
seq = lang + lang_start
else:
seq = torch.tensor([lang_start + lang], device=device, dtype=dtype)
input_ids[i].append( seq )
input_ids[i].append( sep )
position_ids[i].append( generate_position_ids( seq.shape[0] ) )
# inject target quant_level
if quant_levels is not None:
@@ -164,15 +179,20 @@ def fold_inputs(
input_ids[i].append( seq )
input_ids[i].append( sep )
position_ids[i].append( generate_position_ids( seq.shape[0] ) )
# prom / task tokens
for i, prom in enumerate(prom_list):
# list of proms with a possible task token
length = 0
if isinstance(prom, list):
for p in prom:
process_prom_or_task(i, p)
length += process_prom_or_task(i, p)
# raw tensor
else:
process_prom_or_task(i, prom)
length += process_prom_or_task(i, prom)
position_ids[i].append( generate_position_ids( length, sep=False ) )
# tone tokens
for i, tone in enumerate(tone_list):
@@ -183,6 +203,8 @@ def fold_inputs(
input_ids[i].append( seq )
input_ids[i].append( sep )
position_ids[i].append( generate_position_ids( seq.shape[0] ) )
# resp tokens
for i, resp in enumerate(resp_list):
# deinterleaved
@@ -205,6 +227,8 @@ def fold_inputs(
input_ids[i].append( seq )
input_ids[i].append( stop )
position_ids[i].append( generate_position_ids( seq.shape[0] ) )
# interleaved
else:
seq = resp.flatten().to(device=device, dtype=dtype)
@@ -213,6 +237,8 @@ def fold_inputs(
input_ids[i].append( seq )
input_ids[i].append( stop )
position_ids[i].append( generate_position_ids( seq.shape[0] ) )
# targ list
for i, resp in enumerate(targ_list):
@@ -225,6 +251,8 @@ def fold_inputs(
input_ids[i].append( seq )
input_ids[i].append( stop )
position_ids[i].append( generate_position_ids( seq.shape[0] ) )
# interleaved
else:
seq = resp.flatten().to(device=device, dtype=dtype)
@@ -233,11 +261,17 @@ def fold_inputs(
input_ids[i].append( seq )
input_ids[i].append( stop )
position_ids[i].append( generate_position_ids( seq.shape[0] ) )
for i, batch in enumerate(input_ids):
input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=dtype)
position_ids[i] = torch.concat([ torch.tensor(ids, device=device, dtype=dtype) for ids in position_ids[i] ], dim=-1)
return list_to_tensor(input_ids)
input_ids, attention_mask = list_to_tensor(input_ids)
position_ids = list_to_tensor(position_ids, mask=False)
return input_ids, attention_mask, position_ids
# unfold from one unified token ID space to separate token spaces
# to-do: unfold at a specific RVQ level instead if requested
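With these changes fold_inputs returns three aligned tensors instead of two, and each segment's position ids restart from zero rather than counting absolute offsets. A minimal sketch of the updated calling convention, mirroring the call sites changed later in this commit (model standing in for the experimental Model instance):

    input_ids, attention_mask, position_ids = fold_inputs(
        text_list=text_list,
        prom_list=proms_list,
        resp_list=resps_list,
        quant_levels=quant_levels,
    )
    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        eos_token_id=3,
    )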

View File

@@ -316,7 +316,7 @@ class AR_NAR(Base):
def example_usage():
# cfg.trainer.backend = "local"
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_100
@@ -398,7 +398,7 @@ def example_usage():
tasks = cfg.dataset.tasks_list
model = AR_NAR(**kwargs).to(device)
steps = 150 * len(tasks) * cfg.model.experimental.causal_size
steps = 150 * len(tasks) # * cfg.model.experimental.causal_size
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

View File

@@ -9,9 +9,9 @@ def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_
residual = None
for layer in self.layers:
if self.gradient_checkpointing and hidden_states.requires_grad:
hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, **mixer_kwargs, use_reentrant=False )
else:
hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params, **mixer_kwargs )
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
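The checkpointed branch passes **mixer_kwargs together with use_reentrant=False because torch.utils.checkpoint only forwards keyword arguments to the wrapped callable on the non-reentrant path (the reentrant variant rejects them). As a generic sketch of the pattern:

    from torch.utils.checkpoint import checkpoint
    # extra keyword arguments are handed through to `module`; this requires use_reentrant=False
    out = checkpoint(module, *tensor_inputs, use_reentrant=False, **extra_kwargs)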

View File

@@ -701,12 +701,12 @@ class Base(nn.Module):
self.model = MambaMixelModel(
vocab_size=n_resp_tokens,
d_model=d_model,
n_layer=n_layers,
d_intermediate=d_model*4,
ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False} if self.arch_type == "mamba2" else {},
n_layer=n_layers*2,
d_intermediate=0, #d_model*2,
ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if self.arch_type == "mamba2" else {},
rms_norm=True,
fused_add_norm=True,
residual_in_fp32=False,
residual_in_fp32=True,
#attn_layer_idx=attn_layer_idx,
#attn_cfg=attn_cfg,
#initializer_cfg=initializer_cfg,
@@ -722,7 +722,7 @@ class Base(nn.Module):
is_encoder_decoder=False,
is_decoder=True,
use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
residual_in_fp32=False, # breaks for AMP inference
residual_in_fp32=True, # breaks for AMP inference
))
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(

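For context on the Mamba changes: in mamba_ssm's MixerModel, d_intermediate=0 is assumed to drop the interleaved MLP entirely (the block falls back to an identity MLP), which is presumably why n_layer is doubled here to keep a comparable depth. A rough sketch of what one layer reduces to under that assumption:

    # assumed per-layer structure when d_intermediate == 0 (mamba_ssm behaviour)
    # Block:
    #     norm  -> RMSNorm(d_model)
    #     mixer -> Mamba2(d_model, ...)   # ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True}
    #     mlp   -> nn.Identity()          # no gated MLP since d_intermediate == 0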
View File

@@ -59,6 +59,15 @@ class Model(LlmArchClass):
# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
# vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
if hf_attention == "auto":
if AVAILABLE_ATTENTIONS:
hf_attention = AVAILABLE_ATTENTIONS[0]
else:
hf_attention = "eager"
if hf_attention == "xformers":
hf_attention = "mem_efficient"
text_start = 0
text_end = text_start + config.text_tokens
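The commented-out formula above is easier to read with concrete numbers; assuming, purely for illustration, 256 text tokens, 8 RVQ levels, and 1024 audio tokens per level:

    # n_text_tokens + max_levels + (n_audio_tokens * max_levels) + (n_audio_tokens * max_levels) + 1
    #      256      +     8      +        (1024 * 8)             +        (1024 * 8)             + 1  = 16_649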
@@ -82,17 +91,17 @@ class Model(LlmArchClass):
vocab_size = resp_end
if cfg.model.arch_type == "llama":
if config.arch_type == "llama":
super().__init__(config=LlamaConfig(
vocab_size=vocab_size,
hidden_size=d_model,
max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.max_levels * 60, # max-length of 60 seconds
max_position_embeddings=cfg.dataset.frames_per_second * config.max_levels * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout,
num_key_value_heads=n_heads,
sliding_window=cfg.dataset.frames_per_second * cfg.model.max_levels * 12,
sliding_window=cfg.dataset.frames_per_second * config.max_levels * 12,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
@@ -103,7 +112,7 @@ class Model(LlmArchClass):
self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
elif cfg.model.arch_type == "retnet":
elif config.arch_type == "retnet":
super().__init__(config=RetNetConfig(
vocab_size=vocab_size,
decoder_embed_dim=d_model,
@@ -125,16 +134,16 @@ class Model(LlmArchClass):
decoder_normalize_before=True,
))
elif cfg.model.arch_type in ["mamba","mamba2"]:
elif config.arch_type in ["mamba","mamba2"]:
super().__init__(config=MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layers,
d_intermediate=d_model*4,
ssm_cfg={"layer": "Mamba2"} if cfg.model.arch_type == "mamba2" else {},
n_layer=n_layers*2,
d_intermediate=0, # d_model*4,
ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if config.arch_type == "mamba2" else {},
rms_norm=True,
fused_add_norm=True,
residual_in_fp32=True,
residual_in_fp32=False,
))
self.backbone.gradient_checkpointing = gradient_checkpointing
@@ -163,8 +172,87 @@ class Model(LlmArchClass):
if "min_length" in kwargs:
kwargs.pop("min_length")
"""
if "position_ids" in kwargs:
kwargs.pop("position_ids")
if "max_new_tokens" in kwargs:
kwargs.pop("max_new_tokens")
if "max_length" not in kwargs:
kwargs["max_length"] = 500 * (self.hyper_config.resp_levels if self.hyper_config.experimental.interleave else 1)
if "num_last_tokens" not in kwargs:
kwargs["num_last_tokens"] = self.hyper_config.experimental.causal_size
"""
input_ids = kwargs.pop("input_ids")
attention_mask = kwargs.pop("attention_mask", None)
position_ids = kwargs.pop("position_ids", None)
stop_token = kwargs.pop("eos_token_id", 3)
max_steps = kwargs.pop("max_new_tokens", 500)
device = input_ids.device
batch_size = input_ids.shape[0]
sequence_list = [ inputs for inputs in input_ids ]
position_list = [ positions for positions in position_ids ]
start_positions = [ inputs.shape[0] for inputs in input_ids ]
stopped = torch.zeros(batch_size, device=device).bool()
config = self.hyper_config
state = None
disable_tqdm = False
causal_size = config.experimental.causal_size
# get next in sequence
for n in trange(max_steps // max(1, causal_size), desc="AR", disable=disable_tqdm):
output = super().forward(
input_ids=torch.stack(sequence_list),
#attention_mask=attention_mask,
#past_key_values=state,
#position_ids=torch.stack(position_list),
#use_cache=False,
#return_dict=False
)
logits = output[0]
# state = output[1]
r = [ logit[-causal_size:].argmax(dim=1) for logit in logits ]
# append tokens
for i, ri in enumerate(r):
if stop_token in ri:
stopped[i] = True
last_position_id = position_list[i][-1].item() + 1
sequence_list[i] = torch.cat([ sequence_list[i], ri.to(device) ], dim=0)
#position_list[i] = torch.cat([ position_list[i], torch.tensor([ last_position_id + _ for _ in range( ri.shape[0] ) ], device=device, dtype=torch.int32) ])
# stop token found
stopped |= r == stop_token
if stopped.all().item():
break
def _prune(l: Tensor, stop = stop_token):
indices = (l == stop).nonzero()
if len(indices) == 0:
return l
return l[: indices.min().item()]
sequence_list = [ _prune(seq[start_positions[i]:], stop_token) for i, seq in enumerate(sequence_list) ]
return torch.stack(sequence_list)
"""
return super().generate(*args, **kwargs)
"""
def forward(
self,
@@ -188,14 +276,14 @@ class Model(LlmArchClass):
if training:
quant_levels = None if config.experimental.interleave else [ random.randint( 0 if "ar" in config.capabilities else 1, config.max_levels - 1) for _ in range(batch_size) ]
input_ids, attention_mask = fold_inputs(
input_ids, attention_mask, position_ids = fold_inputs(
text_list=text_list,
prom_list=proms_list,
resp_list=resps_list,
targ_list=resps_list,
quant_levels=quant_levels,
)
target_ids, target_attention_mask = fold_inputs(
target_ids, target_attention_mask, target_position_ids = fold_inputs(
text_list=text_list,
prom_list=proms_list,
resp_list=resps_list,
@@ -206,14 +294,16 @@ class Model(LlmArchClass):
return self.forward(
input_ids=input_ids,
labels=target_ids,
position_ids=position_ids,
quant_levels=quant_levels,
)
if config.experimental.interleave:
input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
input_ids, attention_mask, position_ids = fold_inputs( text_list=text_list, prom_list=proms_list )
output = self.generate(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
eos_token_id=3,
do_sample=True,
@@ -225,7 +315,7 @@ class Model(LlmArchClass):
for l in range(config.max_levels):
quant_levels = [ l for _ in range(batch_size) ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
input_ids, attention_mask, position_ids = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
min_length = 1
for batch in input_ids:
min_length = max( min_length, batch.shape[0] + 1 )
@@ -234,6 +324,7 @@ class Model(LlmArchClass):
output = self.generate(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
eos_token_id=3,
do_sample=True,
max_new_tokens=steps,
@@ -273,10 +364,13 @@ class Model(LlmArchClass):
# i HATE the correct way
if labels is not None:
if quant_levels is None:
quant_levels = [0 for _ in range(labels.shape[0])]
# predict the next token for AR, else predict in place
loss = sum([ F.cross_entropy(
logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
label[1:] if quant_level == 0 or "nar" not in config.capabilities else label,
logit[:-config.experimental.causal_size, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
label[config.experimental.causal_size:] if quant_level == 0 or "nar" not in config.capabilities else label,
ignore_index=-100
) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])
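With causal_size == 1 this reduces to the usual next-token shift (logit[:-1, :] against label[1:]); a larger causal_size widens the shift to match the generation loop above, which takes the last causal_size logits per step. As a sketch:

    # causal_size = 2: logits at positions 0..T-3 are scored against labels at positions 2..T-1
    loss = F.cross_entropy(logit[:-2, :], label[2:], ignore_index=-100)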
@@ -372,7 +466,7 @@ def example_usage():
kwargs = {}
model = Model(**kwargs).to(device)
steps = 100 if cfg.model.experimental.interleave else 300
steps = 50 # 100 if cfg.model.experimental.interleave else 300
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""