re-implemented config.model.interleave for the HF-compat experimental method
This commit is contained in:
parent c93d5863fd
commit 406ff7bbe1
@@ -216,7 +216,7 @@ class Model:
 	dropout: float = 0.1 # adjustable dropout value
 	loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
 	kv_heads: int = 0
-	experimental: bool = False
+	experimental: bool = False # for now it sets things to be HF compatible
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
@@ -44,7 +44,8 @@ def fold_inputs(
 
 	text_tokens = 256,
 	audio_tokens = 1024,
-	audio_rvq_levels = cfg.model.max_levels
+	audio_rvq_levels = cfg.model.max_levels,
+	quant_levels = None,
 ):
 	def _create_mask(l, device):
 		seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -75,23 +76,43 @@ def fold_inputs(
 
 	offset = text_tokens
 	for i, prom in enumerate(prom_list):
-		if ignore_index is not None:
-			seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
-		else:
-			seq = prom.flatten().to("cpu", dtype=torch.int64)
-			for idx, token in enumerate( seq ):
-				token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+		if quant_levels is not None:
+			quant_level = quant_levels[i]
+			if ignore_index is not None:
+				seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to("cpu", dtype=torch.int64)
+			else:
+				seq = prom[:, quant_level].to("cpu", dtype=torch.int64)
+				for idx, token in enumerate( seq ):
+					token += offset + ( audio_tokens * quant_level )
+		else:
+			if ignore_index is not None:
+				seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
+			else:
+				seq = prom.flatten().to("cpu", dtype=torch.int64)
+				for idx, token in enumerate( seq ):
+					token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+
 		input_ids[i].append( seq )
 		input_ids[i].append( sep )
 
 	offset = text_tokens + (audio_tokens * audio_rvq_levels)
 	for i, resp in enumerate(resp_list):
-		seq = resp.flatten().to("cpu", dtype=torch.int64)
-		for idx, token in enumerate( seq ):
-			token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
-		input_ids[i].append( seq )
-		input_ids[i].append( stop )
+		if quant_levels is not None:
+			quant_level = quant_levels[i]
+			seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
+			for idx, token in enumerate( seq ):
+				token += offset + ( audio_tokens * quant_level )
+
+			input_ids[i].append( seq )
+			if quant_level == 0:
+				input_ids[i].append( stop )
+		else:
+			seq = resp.flatten().to("cpu", dtype=torch.int64)
+			for idx, token in enumerate( seq ):
+				token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+
+			input_ids[i].append( seq )
+			input_ids[i].append( stop )
 
 	for i, batch in enumerate(input_ids):
 		input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
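Reviewer note, not part of the commit: the folding above maps every modality into one vocabulary by plain offset arithmetic. A minimal sketch of that layout, assuming the defaults shown earlier (text_tokens=256, audio_tokens=1024) and a hypothetical audio_rvq_levels of 8; the helper names are illustrative only, not the repo's API:

# Illustrative sketch only, not code from this commit.
text_tokens = 256
audio_tokens = 1024
audio_rvq_levels = 8  # assumed value for cfg.model.max_levels

def fold_prom_id(code, level):
	# prompt audio codes: one 1024-wide band per RVQ level, right after the text vocab
	return text_tokens + audio_tokens * level + code

def fold_resp_id(code, level):
	# response audio codes: same banding, shifted past all prompt bands
	return text_tokens + audio_tokens * audio_rvq_levels + audio_tokens * level + code

assert fold_prom_id(0, 0) == 256
assert fold_resp_id(0, 0) == 256 + 1024 * audio_rvq_levels

With quant_levels set, only the band for that single RVQ level is populated per sequence, which is what lets the non-interleaved path model one level per pass; with quant_levels unset, levels are interleaved frame by frame via idx % audio_rvq_levels.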
@@ -99,6 +120,7 @@ def fold_inputs(
 	return list_to_tensor(input_ids)
 
 # unfold from one unified token ID space to separate token spaces
+# to-do: unfold at a specific RVQ level instead if requested
 def unfold_outputs(
 	output_ids,
 
@@ -107,7 +129,8 @@ def unfold_outputs(
 
 	text_tokens = 256,
 	audio_tokens = 1024,
-	audio_rvq_levels = cfg.model.max_levels
+	audio_rvq_levels = cfg.model.max_levels,
+	quant_levels = None,
 ):
 	device = output_ids.device
 	batch_size = output_ids.shape[0]
@@ -129,30 +152,33 @@ def unfold_outputs(
 			elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
 				resp_list[i].append( (id - text_tokens) % audio_tokens )
 
-		prom_len = len(prom_list[i])
-		if prom_len % audio_rvq_levels == 0 and False:
-			prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
-		else:
-			bins = [ [] for _ in range(audio_rvq_levels) ]
-			for pos in range( prom_len ):
-				rvq = pos % audio_rvq_levels
-				bins[rvq].append( prom_list[i][pos] )
-			nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
-			bins = bins[:nearest]
-			prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
-
-		resp_len = len(resp_list[i])
-		if len(resp_list[i]) % audio_rvq_levels == 0 and False:
-			resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
-		else:
-			bins = [ [] for _ in range(audio_rvq_levels) ]
-			for pos in range( resp_len ):
-				rvq = pos % audio_rvq_levels
-				bins[rvq].append( resp_list[i][pos] )
-			nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
-			bins = bins[:nearest]
-			resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
+		if quant_levels is not None:
+			prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64)
+			resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=torch.int64)
+		else:
+			prom_len = len(prom_list[i])
+			if prom_len % audio_rvq_levels == 0 and False:
+				prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
+			else:
+				bins = [ [] for _ in range(audio_rvq_levels) ]
+				for pos in range( prom_len ):
+					rvq = pos % audio_rvq_levels
+					bins[rvq].append( prom_list[i][pos] )
+				nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
+				bins = bins[:nearest]
+				prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
+
+			resp_len = len(resp_list[i])
+			if len(resp_list[i]) % audio_rvq_levels == 0 and False:
+				resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
+			else:
+				bins = [ [] for _ in range(audio_rvq_levels) ]
+				for pos in range( resp_len ):
+					rvq = pos % audio_rvq_levels
+					bins[rvq].append( resp_list[i][pos] )
+				nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
+				bins = bins[:nearest]
+				resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
 
 		text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)
 
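Reviewer note, not part of the commit: the else-branch above undoes the interleaving by binning every audio_rvq_levels-th token into the same RVQ level. A rough sketch of that regrouping, with a simpler trim than the commit's bins[:nearest] slice (illustrative names, not the repo's API):

# Illustrative sketch only, not code from this commit.
import torch

audio_rvq_levels = 8  # assumed value for cfg.model.max_levels

def deinterleave(flat_codes):
	# regroup a flat [t * levels] stream into [levels][t], then transpose to [t, levels]
	bins = [ [] for _ in range(audio_rvq_levels) ]
	for pos, code in enumerate(flat_codes):
		bins[pos % audio_rvq_levels].append(code)
	# drop any trailing partial frame so every level has the same length
	t = min(len(b) for b in bins)
	bins = [ b[:t] for b in bins ]
	return torch.tensor(bins, dtype=torch.int64).t()

print(deinterleave(list(range(17))).shape)  # torch.Size([2, 8])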
@@ -158,7 +158,22 @@ class Model(LlmArchClass):
 
 			self.backbone.gradient_checkpointing = gradient_checkpointing
 
+	def generate(
+		self,
+		*args,
+		**kwargs
+	):
+		if SELECTED_ARCH == "mamba":
+			kwargs["cg"] = True
+
+			if "attention_mask" in kwargs:
+				kwargs.pop("attention_mask")
+
+			if "do_sample" in kwargs:
+				kwargs.pop("do_sample")
+
+		return super().forward(*args, **kwargs)
+
 	def forward(
 		self,
 		*args,
@@ -239,13 +254,9 @@ def example_usage():
 	proms_list = proms_list[:1]
 	resps_list = resps_list[:1]
 
-	input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list)
-	target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100)
-	prefix_input_ids, prefix_attention_mask = fold_inputs(text_list, proms_list)
-
 	kwargs = {}
 	model = Model(**kwargs).to(device)
-	steps = 50
+	steps = 50 if cfg.model.interleave else 250
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@@ -312,15 +323,46 @@ def example_usage():
 	print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
 	@torch.inference_mode()
-	def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*60 ):
+	def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
 		engine.eval()
-		if SELECTED_ARCH == "mamba":
-			output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3)
-		else:
-			output = model.generate(input_ids=prefix_input_ids, attention_mask=prefix_attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
-
-		unfolded = unfold_outputs( output )
-		for i, batch in enumerate(unfolded["resp_list"]):
+		target_length = 0
+		resp_list = None
+		if cfg.model.interleave:
+			input_ids, attention_mask = fold_inputs(text_list, proms_list)
+			output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
+
+			unfolded = unfold_outputs( output )
+			resp_list = unfolded["resp_list"]
+		else:
+			resp_list = [ [] for _ in range(len(text_list)) ]
+			for l in range(cfg.model.max_levels):
+				quant_levels = [ l ]
+				input_ids, attention_mask = fold_inputs(text_list, proms_list, quant_levels=quant_levels)
+				min_length = len(input_ids[0])
+
+				output = model.generate(
+					input_ids=input_ids,
+					attention_mask=attention_mask,
+					min_length=min_length+(steps if l > 0 else 0),
+					max_length=min_length+steps,
+					eos_token_id=3 if l == 0 else None ,
+					do_sample=False
+				)
+
+				unfolded = unfold_outputs( output, quant_levels=quant_levels )
+
+				if l == 0:
+					steps = 0
+
+				for batch, resp in enumerate(unfolded["resp_list"]):
+					if l == 0:
+						steps = max( steps, resp.shape[0] )
+					resp_list[batch].append( resp )
+
+			for i, resp in enumerate( resp_list ):
+				resp_list[i] = torch.stack( resp ).t()
+
+		for i, batch in enumerate(resp_list):
 			_ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
 
 		unload_model()
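Reviewer note, not part of the commit: in the non-interleaved branch each RVQ level is decoded separately and the per-level sequences are then stacked into a [timesteps, levels] code matrix via torch.stack( resp ).t() before decoding to audio. A toy sketch of that last step with hypothetical tensors:

# Illustrative sketch only, not code from this commit.
import torch

timesteps, levels = 75, 8  # hypothetical lengths
# one decoded sequence per RVQ level, all trimmed to the same length
per_level = [ torch.randint(0, 1024, (timesteps,)) for _ in range(levels) ]

codes = torch.stack(per_level).t()  # [levels, t] -> [t, levels], mirroring torch.stack( resp ).t() above
print(codes.shape)  # torch.Size([75, 8])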
@@ -330,6 +372,13 @@ def example_usage():
 	t = trange(steps)
 	for i in t:
 		stats = {"step": i}
+
+		batch_size = len(text_list)
+		quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,))
+
+		input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list, quant_levels=quant_levels)
+		target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100, quant_levels=quant_levels)
+
 		if SELECTED_ARCH == "mamba":
 			stats |= engine.traverse(input_ids=input_ids, labels=target_ids)
 		else:
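Reviewer note, not part of the commit: the training-step change boils down to drawing one random quant level per batch element when interleaving is off, and folding the targets with ignore_index=-100 so prompt positions drop out of the cross-entropy loss. A self-contained sketch under those assumptions (toy shapes and hypothetical names, not the repo's API):

# Illustrative sketch only, not code from this commit.
import torch
import torch.nn.functional as F

batch_size, max_levels = 4, 8
vocab = 256 + 1024 * 8 * 2  # text band + prompt bands + response bands, as in the fold above
interleave = False

# one random quant level per batch element when not interleaving, mirroring the diff above
quant_levels = None if interleave else torch.randint(0, max_levels, (batch_size,))

# toy logits/targets; in the real fold, prompt positions are filled with ignore_index=-100
logits = torch.randn(batch_size, 16, vocab)
targets = torch.randint(0, vocab, (batch_size, 16))
targets[:, :6] = -100  # prompt span: excluded from the loss

loss = F.cross_entropy(logits.transpose(1, 2), targets, ignore_index=-100)
print(quant_levels, loss.item())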