re-implemented config.model.interleave for the HF-compat experimental method

This commit is contained in:
mrq 2024-06-04 14:19:52 -05:00
parent c93d5863fd
commit 406ff7bbe1
3 changed files with 121 additions and 46 deletions

View File

@ -216,7 +216,7 @@ class Model:
dropout: float = 0.1 # adjustable dropout value
loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
kv_heads: int = 0
experimental: bool = False
experimental: bool = False # for now it sets things to be HF compatible
def get(self, name=None):
return [ self ] if not name or self.name == name else []

View File

@ -44,7 +44,8 @@ def fold_inputs(
text_tokens = 256,
audio_tokens = 1024,
audio_rvq_levels = cfg.model.max_levels
audio_rvq_levels = cfg.model.max_levels,
quant_levels = None,
):
def _create_mask(l, device):
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@ -75,23 +76,43 @@ def fold_inputs(
offset = text_tokens
for i, prom in enumerate(prom_list):
if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
if quant_levels is not None:
quant_level = quant_levels[i]
if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to("cpu", dtype=torch.int64)
else:
seq = prom[:, quant_level].to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * quant_level )
else:
seq = prom.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
else:
seq = prom.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
input_ids[i].append( seq )
input_ids[i].append( sep )
offset = text_tokens + (audio_tokens * audio_rvq_levels)
for i, resp in enumerate(resp_list):
seq = resp.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
input_ids[i].append( seq )
input_ids[i].append( stop )
if quant_levels is not None:
quant_level = quant_levels[i]
seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * quant_level )
input_ids[i].append( seq )
if quant_level == 0:
input_ids[i].append( stop )
else:
seq = resp.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
input_ids[i].append( seq )
input_ids[i].append( stop )
for i, batch in enumerate(input_ids):
input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
@ -99,6 +120,7 @@ def fold_inputs(
return list_to_tensor(input_ids)
# unfold from one unified token ID space to separate token spaces
# to-do: unfold at a specific RVQ level instead if requested
def unfold_outputs(
output_ids,
@ -107,7 +129,8 @@ def unfold_outputs(
text_tokens = 256,
audio_tokens = 1024,
audio_rvq_levels = cfg.model.max_levels
audio_rvq_levels = cfg.model.max_levels,
quant_levels = None,
):
device = output_ids.device
batch_size = output_ids.shape[0]
@ -129,30 +152,33 @@ def unfold_outputs(
elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
resp_list[i].append( (id - text_tokens) % audio_tokens )
prom_len = len(prom_list[i])
if prom_len % audio_rvq_levels == 0 and False:
prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
if quant_levels is not None:
prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64)
resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=torch.int64)
else:
bins = [ [] for _ in range(audio_rvq_levels) ]
for pos in range( prom_len ):
rvq = pos % audio_rvq_levels
bins[rvq].append( prom_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
prom_len = len(prom_list[i])
if prom_len % audio_rvq_levels == 0 and False:
prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
else:
bins = [ [] for _ in range(audio_rvq_levels) ]
for pos in range( prom_len ):
rvq = pos % audio_rvq_levels
bins[rvq].append( prom_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
resp_len = len(resp_list[i])
if len(resp_list[i]) % audio_rvq_levels == 0 and False:
resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
else:
bins = [ [] for _ in range(audio_rvq_levels) ]
for pos in range( resp_len ):
rvq = pos % audio_rvq_levels
bins[rvq].append( resp_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
resp_len = len(resp_list[i])
if len(resp_list[i]) % audio_rvq_levels == 0 and False:
resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
else:
bins = [ [] for _ in range(audio_rvq_levels) ]
for pos in range( resp_len ):
rvq = pos % audio_rvq_levels
bins[rvq].append( resp_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)

View File

@ -158,7 +158,22 @@ class Model(LlmArchClass):
self.backbone.gradient_checkpointing = gradient_checkpointing
def generate(
self,
*args,
**kwargs
):
if SELECTED_ARCH == "mamba":
kwargs["cg"] = True
if "attention_mask" in kwargs:
kwargs.pop("attention_mask")
if "do_sample" in kwargs:
kwargs.pop("do_sample")
return super().forward(*args, **kwargs)
def forward(
self,
*args,
@ -239,13 +254,9 @@ def example_usage():
proms_list = proms_list[:1]
resps_list = resps_list[:1]
input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list)
target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100)
prefix_input_ids, prefix_attention_mask = fold_inputs(text_list, proms_list)
kwargs = {}
model = Model(**kwargs).to(device)
steps = 50
steps = 50 if cfg.model.interleave else 250
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@ -312,15 +323,46 @@ def example_usage():
print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*60 ):
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
engine.eval()
if SELECTED_ARCH == "mamba":
output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3)
target_length = 0
resp_list = None
if cfg.model.interleave:
input_ids, attention_mask = fold_inputs(text_list, proms_list)
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
unfolded = unfold_outputs( output )
resp_list = unfolded["resp_list"]
else:
output = model.generate(input_ids=prefix_input_ids, attention_mask=prefix_attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
resp_list = [ [] for _ in range(len(text_list)) ]
for l in range(cfg.model.max_levels):
quant_levels = [ l ]
input_ids, attention_mask = fold_inputs(text_list, proms_list, quant_levels=quant_levels)
min_length = len(input_ids[0])
unfolded = unfold_outputs( output )
for i, batch in enumerate(unfolded["resp_list"]):
output = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
min_length=min_length+(steps if l > 0 else 0),
max_length=min_length+steps,
eos_token_id=3 if l == 0 else None ,
do_sample=False
)
unfolded = unfold_outputs( output, quant_levels=quant_levels )
if l == 0:
steps = 0
for batch, resp in enumerate(unfolded["resp_list"]):
if l == 0:
steps = max( steps, resp.shape[0] )
resp_list[batch].append( resp )
for i, resp in enumerate( resp_list ):
resp_list[i] = torch.stack( resp ).t()
for i, batch in enumerate(resp_list):
_ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
unload_model()
@ -330,6 +372,13 @@ def example_usage():
t = trange(steps)
for i in t:
stats = {"step": i}
batch_size = len(text_list)
quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,))
input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list, quant_levels=quant_levels)
target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100, quant_levels=quant_levels)
if SELECTED_ARCH == "mamba":
stats |= engine.traverse(input_ids=input_ids, labels=target_ids)
else: