ughghghhhh
parent ed373957e2
commit 2a1794c084
@@ -58,9 +58,11 @@ def fold_inputs(
     stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
     return (seq < stop).float() # (b t)

-def list_to_tensor(x_list: list[Tensor]):
+def list_to_tensor(x_list: list[Tensor], mask=True):
     l = list(map(len, x_list))
     x = pad_sequence(x_list).t()
+    if not mask:
+        return x

     m = _create_mask(l, x_list[0].device)
     m = m.to(x)
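The new `mask` flag lets a caller skip building the attention mask when only padding is needed (the position IDs below use this). A minimal usage sketch, assuming the helpers above are in scope and the inputs are 1-D token tensors:

    seqs = [ torch.tensor([1, 2, 3]), torch.tensor([4, 5]) ]
    x, m = list_to_tensor(seqs)           # x: (2, 3) padded batch, m: (2, 3) float mask
    p = list_to_tensor(seqs, mask=False)  # padded batch only, no mask is built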
@@ -68,7 +70,7 @@ def fold_inputs(

     def process_prom_or_task(i, prom):
         if prom is None:
-            return
+            return 0

         if isinstance(prom, str):
             task = get_task_symmap()[f'<{input}>']
@@ -76,7 +78,8 @@ def fold_inputs(

             input_ids[i].append( seq )
             input_ids[i].append( sep )
-            return
+
+            return seq.shape[0] + 1

         # deinterleaved
         if quant_levels is not None:
@@ -99,6 +102,11 @@ def fold_inputs(
         input_ids[i].append( seq )
         input_ids[i].append( sep )

+        return seq.shape[0] + 1
+
+    def generate_position_ids( length, sep=True ):
+        return [ i for i in range( length + (1 if sep else 0) ) ]
+
     """
     if quant_levels is not None:
         resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
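`generate_position_ids` just enumerates positions for one segment, optionally counting the trailing separator token, so for a 3-token segment:

    generate_position_ids(3)              # [0, 1, 2, 3] -- 3 tokens plus the separator
    generate_position_ids(3, sep=False)   # [0, 1, 2]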
@@ -109,6 +117,7 @@ def fold_inputs(

     batch_size = len(text_list)
     input_ids = [ [] for _ in range(batch_size) ]
+    position_ids = [ [] for _ in range(batch_size) ]

     offset = 0

@@ -142,17 +151,23 @@ def fold_inputs(
             seq = text + text_start
         else:
             seq = torch.tensor([text_start + text], device=device, dtype=dtype)

         input_ids[i].append( seq )
         input_ids[i].append( sep )
+
+        position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
     # lang tokens
     for i, lang in enumerate(lang_list):
         if isinstance(lang, torch.Tensor):
             seq = lang + lang_start
         else:
             seq = torch.tensor([lang_start + lang], device=device, dtype=dtype)

         input_ids[i].append( seq )
         input_ids[i].append( sep )
+
+        position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
     # inject target quant_level
     if quant_levels is not None:
@@ -164,15 +179,20 @@ def fold_inputs(
             input_ids[i].append( seq )
             input_ids[i].append( sep )
+
+            position_ids[i].append( generate_position_ids( seq.shape[0] ) )

     # prom / task tokens
     for i, prom in enumerate(prom_list):
         # list of proms with a possible task token
+        length = 0
         if isinstance(prom, list):
             for p in prom:
-                process_prom_or_task(i, p)
+                length += process_prom_or_task(i, p)
         # raw tensor
         else:
-            process_prom_or_task(i, prom)
+            length += process_prom_or_task(i, prom)

+        position_ids[i].append( generate_position_ids( length, sep=False ) )
+
     # tone tokens
     for i, tone in enumerate(tone_list):
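Since a prompt may be a list of segments (plus an optional task token), the per-segment counts returned by `process_prom_or_task` (each `seq.shape[0] + 1`, counting the separator it appends) are summed into one position ramp with no extra separator offset. For two hypothetical segments of 3 and 2 tokens:

    length = (3 + 1) + (2 + 1)                  # 7
    generate_position_ids(length, sep=False)    # [0, 1, 2, 3, 4, 5, 6]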
@@ -183,6 +203,8 @@ def fold_inputs(
         input_ids[i].append( seq )
         input_ids[i].append( sep )
+
+        position_ids[i].append( generate_position_ids( seq.shape[0] ) )

     # resp tokens
     for i, resp in enumerate(resp_list):
         # deinterleaved
@@ -205,6 +227,8 @@ def fold_inputs(

             input_ids[i].append( seq )
             input_ids[i].append( stop )
+
+            position_ids[i].append( generate_position_ids( seq.shape[0] ) )
         # interleaved
         else:
             seq = resp.flatten().to(device=device, dtype=dtype)
@@ -213,6 +237,8 @@ def fold_inputs(

             input_ids[i].append( seq )
             input_ids[i].append( stop )
+
+            position_ids[i].append( generate_position_ids( seq.shape[0] ) )

     # targ list
     for i, resp in enumerate(targ_list):
@@ -225,6 +251,8 @@ def fold_inputs(

             input_ids[i].append( seq )
             input_ids[i].append( stop )
+
+            position_ids[i].append( generate_position_ids( seq.shape[0] ) )
         # interleaved
         else:
             seq = resp.flatten().to(device=device, dtype=dtype)
@@ -233,11 +261,17 @@ def fold_inputs(

             input_ids[i].append( seq )
             input_ids[i].append( stop )
+
+            position_ids[i].append( generate_position_ids( seq.shape[0] ) )

     for i, batch in enumerate(input_ids):
         input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=dtype)
+        position_ids[i] = torch.concat([ torch.tensor(ids, device=device, dtype=dtype) for ids in position_ids[i] ], dim=-1)

-    return list_to_tensor(input_ids)
+    input_ids, attention_mask = list_to_tensor(input_ids)
+    position_ids = list_to_tensor(position_ids, mask=False)
+
+    return input_ids, attention_mask, position_ids

 # unfold from one unified token ID space to separate token spaces
 # to-do: unfold at a specific RVQ level instead if requested
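Call sites now unpack three tensors instead of two; the new contract, sketched with hypothetical inputs, is:

    input_ids, attention_mask, position_ids = fold_inputs(
        text_list=text_list,
        prom_list=proms_list,
        resp_list=resps_list,
    )
    # input_ids / attention_mask: (batch, padded_len); position_ids is padded the same way
    # but built with list_to_tensor(..., mask=False), so no second mask is produced.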
@@ -316,7 +316,7 @@ class AR_NAR(Base):


 def example_usage():
-    # cfg.trainer.backend = "local"
+    cfg.trainer.backend = "local"
     cfg.hyperparameters.gradient_accumulation_steps = 1
     if cfg.audio_backend == "dac":
         cfg.sample_rate = 44_100
@@ -398,7 +398,7 @@ def example_usage():
     tasks = cfg.dataset.tasks_list

     model = AR_NAR(**kwargs).to(device)
-    steps = 150 * len(tasks) * cfg.model.experimental.causal_size
+    steps = 150 * len(tasks) # * cfg.model.experimental.causal_size

     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -9,9 +9,9 @@ def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_
     residual = None
     for layer in self.layers:
         if self.gradient_checkpointing and hidden_states.requires_grad:
-            hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
+            hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, **mixer_kwargs, use_reentrant=False )
         else:
-            hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
+            hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params, **mixer_kwargs )
     if not self.fused_add_norm:
         residual = (hidden_states + residual) if residual is not None else hidden_states
         hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
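The patch forwards `mixer_kwargs` through both the checkpointed and the plain layer call. A minimal sketch of the pattern (the wrapper name and kwargs are hypothetical); torch's non-reentrant checkpoint path accepts keyword arguments for the wrapped callable, hence `use_reentrant=False`:

    from torch.utils.checkpoint import checkpoint

    def run_layer(layer, hidden_states, residual, inference_params, mixer_kwargs, use_ckpt):
        if use_ckpt and hidden_states.requires_grad:
            # keyword args are forwarded into layer(...) by checkpoint() in non-reentrant mode
            return checkpoint( layer, hidden_states, residual, inference_params=inference_params, **mixer_kwargs, use_reentrant=False )
        return layer( hidden_states, residual, inference_params=inference_params, **mixer_kwargs )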
@@ -701,12 +701,12 @@ class Base(nn.Module):
             self.model = MambaMixelModel(
                 vocab_size=n_resp_tokens,
                 d_model=d_model,
-                n_layer=n_layers,
-                d_intermediate=d_model*4,
-                ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False} if self.arch_type == "mamba2" else {},
+                n_layer=n_layers*2,
+                d_intermediate=0, #d_model*2,
+                ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if self.arch_type == "mamba2" else {},
                 rms_norm=True,
                 fused_add_norm=True,
-                residual_in_fp32=False,
+                residual_in_fp32=True,
                 #attn_layer_idx=attn_layer_idx,
                 #attn_cfg=attn_cfg,
                 #initializer_cfg=initializer_cfg,
@@ -722,7 +722,7 @@ class Base(nn.Module):
                 is_encoder_decoder=False,
                 is_decoder=True,
                 use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
-                residual_in_fp32=False, # breaks for AMP inference
+                residual_in_fp32=True, # breaks for AMP inference
             ))
             if self.gradient_checkpointing and not self.model.gradient_checkpointing:
                 self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
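If the reading is right, `d_intermediate=0` drops the MLP half of each block (mamba_ssm substitutes an identity there), which is presumably why `n_layer` is doubled to keep a comparable depth budget, and `residual_in_fp32=True` keeps the residual stream in fp32 when the rest of the model runs under autocast, e.g.:

    import torch

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(input_ids)  # hypothetical call; residuals stay fp32, activations fp16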
@@ -59,6 +59,15 @@ class Model(LlmArchClass):
         # text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
         # vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1

+        if hf_attention == "auto":
+            if AVAILABLE_ATTENTIONS:
+                hf_attention = AVAILABLE_ATTENTIONS[0]
+            else:
+                hf_attention = "eager"
+
+        if hf_attention == "xformers":
+            hf_attention = "mem_efficient"
+
         text_start = 0
         text_end = text_start + config.text_tokens

@@ -82,17 +91,17 @@ class Model(LlmArchClass):

         vocab_size = resp_end

-        if cfg.model.arch_type == "llama":
+        if config.arch_type == "llama":
             super().__init__(config=LlamaConfig(
                 vocab_size=vocab_size,
                 hidden_size=d_model,
-                max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.max_levels * 60, # max-length of 60 seconds
+                max_position_embeddings=cfg.dataset.frames_per_second * config.max_levels * 60, # max-length of 60 seconds
                 intermediate_size=d_model*4,
                 num_hidden_layers=n_layers,
                 num_attention_heads=n_heads,
                 attention_dropout=p_dropout,
                 num_key_value_heads=n_heads,
-                sliding_window=cfg.dataset.frames_per_second * cfg.model.max_levels * 12,
+                sliding_window=cfg.dataset.frames_per_second * config.max_levels * 12,
                 hidden_act="gelu",
                 is_encoder_decoder=False,
                 is_decoder=True,
@@ -103,7 +112,7 @@ class Model(LlmArchClass):
             self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                 use_reentrant=False
             ))
-        elif cfg.model.arch_type == "retnet":
+        elif config.arch_type == "retnet":
             super().__init__(config=RetNetConfig(
                 vocab_size=vocab_size,
                 decoder_embed_dim=d_model,
@@ -125,16 +134,16 @@ class Model(LlmArchClass):

                 decoder_normalize_before=True,
             ))
-        elif cfg.model.arch_type in ["mamba","mamba2"]:
+        elif config.arch_type in ["mamba","mamba2"]:
             super().__init__(config=MambaConfig(
                 vocab_size=vocab_size,
                 d_model=d_model,
-                n_layer=n_layers,
-                d_intermediate=d_model*4,
-                ssm_cfg={"layer": "Mamba2"} if cfg.model.arch_type == "mamba2" else {},
+                n_layer=n_layers*2,
+                d_intermediate=0, # d_model*4,
+                ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if config.arch_type == "mamba2" else {},
                 rms_norm=True,
                 fused_add_norm=True,
-                residual_in_fp32=True,
+                residual_in_fp32=False,
             ))

             self.backbone.gradient_checkpointing = gradient_checkpointing
@@ -163,8 +172,87 @@ class Model(LlmArchClass):

         if "min_length" in kwargs:
             kwargs.pop("min_length")

+        """
+        if "position_ids" in kwargs:
+            kwargs.pop("position_ids")
+
+        if "max_new_tokens" in kwargs:
+            kwargs.pop("max_new_tokens")
+
+        if "max_length" not in kwargs:
+            kwargs["max_length"] = 500 * (self.hyper_config.resp_levels if self.hyper_config.experimental.interleave else 1)
+
+        if "num_last_tokens" not in kwargs:
+            kwargs["num_last_tokens"] = self.hyper_config.experimental.causal_size
+        """
+
+        input_ids = kwargs.pop("input_ids")
+        attention_mask = kwargs.pop("attention_mask", None)
+        position_ids = kwargs.pop("position_ids", None)
+
+        stop_token = kwargs.pop("eos_token_id", 3)
+        max_steps = kwargs.pop("max_new_tokens", 500)
+
+        device = input_ids.device
+        batch_size = input_ids.shape[0]
+
+        sequence_list = [ inputs for inputs in input_ids ]
+        position_list = [ positions for positions in position_ids ]
+
+        start_positions = [ inputs.shape[0] for inputs in input_ids ]
+
+        stopped = torch.zeros(batch_size, device=device).bool()
+
+        config = self.hyper_config
+        state = None
+        disable_tqdm = False
+        causal_size = config.experimental.causal_size
+
+        # get next in sequence
+        for n in trange(max_steps // max(1, causal_size), desc="AR", disable=disable_tqdm):
+            output = super().forward(
+                input_ids=torch.stack(sequence_list),
+                #attention_mask=attention_mask,
+                #past_key_values=state,
+                #position_ids=torch.stack(position_list),
+                #use_cache=False,
+                #return_dict=False
+            )
+
+            logits = output[0]
+            # state = output[1]
+
+            r = [ logit[-causal_size:].argmax(dim=1) for logit in logits ]
+
+            # append tokens
+            for i, ri in enumerate(r):
+                if stop_token in ri:
+                    stopped[i] = True
+
+                last_position_id = position_list[i][-1].item() + 1
+                sequence_list[i] = torch.cat([ sequence_list[i], ri.to(device) ], dim=0)
+                #position_list[i] = torch.cat([ position_list[i], torch.tensor([ last_position_id + _ for _ in range( ri.shape[0] ) ], device=device, dtype=torch.int32) ])
+
+            # stop token found
+            stopped |= r == stop_token
+            if stopped.all().item():
+                break
+
+        def _prune(l: Tensor, stop = stop_token):
+            indices = (l == stop).nonzero()
+
+            if len(indices) == 0:
+                return l
+
+            return l[: indices.min().item()]
+
+        sequence_list = [ _prune(seq[start_positions[i]:], stop_token) for i, seq in enumerate(sequence_list) ]
+        return torch.stack(sequence_list)
+
+        """
         return super().generate(*args, **kwargs)
+        """

     def forward(
         self,
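The new `generate` bypasses the library generation path entirely: stack the folded prompts, take the last `causal_size` logits per step, argmax, append, and stop once every sequence has emitted the stop token, then prune at the first stop. A stripped-down sketch of the same loop, with a hypothetical `step_logits_fn` standing in for `super().forward()` and equal-length (already padded) prompts assumed:

    import torch

    def greedy_decode(step_logits_fn, seqs, stop_token=3, causal_size=1, max_steps=500):
        # seqs: list of equal-length 1-D LongTensors (the folded prompts)
        stopped = torch.zeros(len(seqs), dtype=torch.bool)
        for _ in range(max_steps // max(1, causal_size)):
            logits = step_logits_fn(torch.stack(seqs))          # (b, t, vocab)
            nexts = logits[:, -causal_size:, :].argmax(dim=-1)  # (b, causal_size)
            for i, nxt in enumerate(nexts):
                seqs[i] = torch.cat([seqs[i], nxt])
                stopped[i] |= bool((nxt == stop_token).any())
            if stopped.all():
                break
        def prune(seq):  # cut at the first stop token, if any
            hits = (seq == stop_token).nonzero()
            return seq if len(hits) == 0 else seq[: hits.min().item()]
        return [ prune(s) for s in seqs ]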
@@ -188,14 +276,14 @@ class Model(LlmArchClass):
         if training:
             quant_levels = None if config.experimental.interleave else [ random.randint( 0 if "ar" in config.capabilities else 1, config.max_levels - 1) for _ in range(batch_size) ]

-            input_ids, attention_mask = fold_inputs(
+            input_ids, attention_mask, position_ids = fold_inputs(
                 text_list=text_list,
                 prom_list=proms_list,
                 resp_list=resps_list,
                 targ_list=resps_list,
                 quant_levels=quant_levels,
             )
-            target_ids, target_attention_mask = fold_inputs(
+            target_ids, target_attention_mask, target_position_ids = fold_inputs(
                 text_list=text_list,
                 prom_list=proms_list,
                 resp_list=resps_list,
@@ -206,14 +294,16 @@ class Model(LlmArchClass):
             return self.forward(
                 input_ids=input_ids,
                 labels=target_ids,
+                position_ids=position_ids,

                 quant_levels=quant_levels,
             )

         if config.experimental.interleave:
-            input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
+            input_ids, attention_mask, position_ids = fold_inputs( text_list=text_list, prom_list=proms_list )
             output = self.generate(
                 input_ids=input_ids,
+                position_ids=position_ids,
                 attention_mask=attention_mask,
                 eos_token_id=3,
                 do_sample=True,
@@ -225,7 +315,7 @@ class Model(LlmArchClass):
         for l in range(config.max_levels):
             quant_levels = [ l for _ in range(batch_size) ]

-            input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
+            input_ids, attention_mask, position_ids = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
             min_length = 1
             for batch in input_ids:
                 min_length = max( min_length, batch.shape[0] + 1 )
@@ -234,6 +324,7 @@ class Model(LlmArchClass):
             output = self.generate(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
+                position_ids=position_ids,
                 eos_token_id=3,
                 do_sample=True,
                 max_new_tokens=steps,
@@ -273,10 +364,13 @@ class Model(LlmArchClass):

         # i HATE the correct way
         if labels is not None:
+            if quant_levels is None:
+                quant_levels = [0 for _ in range(labels.shape[0])]

+            # predict the next token for AR, else predict in place
             loss = sum([ F.cross_entropy(
-                logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
-                label[1:] if quant_level == 0 or "nar" not in config.capabilities else label,
+                logit[:-config.experimental.causal_size, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
+                label[config.experimental.causal_size:] if quant_level == 0 or "nar" not in config.capabilities else label,
                 ignore_index=-100
             ) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])

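With a causal window larger than one token, the usual next-token shift generalizes: drop the last `causal_size` logits and the first `causal_size` labels so position t predicts token t + causal_size. A toy shape check (sizes hypothetical):

    import torch
    import torch.nn.functional as F

    causal_size = 2
    logits = torch.randn(10, 32)            # (t, vocab)
    labels = torch.randint(0, 32, (10,))    # (t,)

    loss = F.cross_entropy(
        logits[:-causal_size, :],           # positions 0..7 ...
        labels[causal_size:],               # ... predict tokens 2..9
        ignore_index=-100,
    )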
@@ -372,7 +466,7 @@ def example_usage():

     kwargs = {}
     model = Model(**kwargs).to(device)
-    steps = 100 if cfg.model.experimental.interleave else 300
+    steps = 50 # 100 if cfg.model.experimental.interleave else 300

     optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
     scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""