forgot one crucial detail (you *need* the previous RVQ level as a prefix to keep coherence between all RVQ levels) (the experimental deinterleaved path is still a bit crusty though)

This commit is contained in:
mrq 2024-06-04 18:30:30 -05:00
parent 2ffad5cb6f
commit 0aa01ba31a
2 changed files with 117 additions and 21 deletions
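
For context on what the commit message means, a minimal sketch of the deinterleaved folding this diff implements: each RVQ level's codes are shifted into a disjoint id range, and decoding level L prefixes the response section with the folded codes of level L-1 so the levels stay coherent. The constants and the `fold_resp_prefix` helper below are illustrative assumptions, not the repo's API; only the offset arithmetic mirrors the diff.

```python
import torch

# assumed toy vocabulary layout (illustrative, not the repo's config)
text_tokens = 256        # text ids occupy [0, text_tokens)
audio_tokens = 1024      # codes per RVQ level
audio_rvq_levels = 8     # number of RVQ levels

def fold_resp_prefix(resp_codes: torch.Tensor, quant_level: int) -> torch.Tensor:
    """Fold the *previous* RVQ level's resp codes into the shared id space.

    resp ids start after the text and prom ranges; each level is then shifted
    by a per-level offset, mirroring the diff's
    `offset + ( audio_tokens * quant_level )` arithmetic.
    """
    prev_level = quant_level - 1
    if prev_level < 0:
        # level 0 has no previous level to condition on
        return torch.empty(0, dtype=torch.int64)
    offset = text_tokens + (audio_tokens * audio_rvq_levels)
    return resp_codes[:, prev_level].to(torch.int64) + offset + (audio_tokens * prev_level)

# e.g. (frames, levels) codes: decoding level 2 prefixes the sequence
# with level 1's codes shifted into their own id range
codes = torch.randint(0, audio_tokens, (75, audio_rvq_levels))
prefix = fold_resp_prefix(codes, quant_level=2)
```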

View File

@ -36,6 +36,7 @@ def fold_inputs(
text_list = [],
prom_list = [],
resp_list = [],
targ_list = [],
ignore_index = None,
@ -46,6 +47,7 @@ def fold_inputs(
audio_tokens = 1024,
audio_rvq_levels = cfg.model.max_levels,
quant_levels = None,
experimental = False
):
def _create_mask(l, device):
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@ -76,6 +78,7 @@ def fold_inputs(
offset = text_tokens
for i, prom in enumerate(prom_list):
# deinterleaved
if quant_levels is not None:
quant_level = quant_levels[i]
if ignore_index is not None:
@ -84,6 +87,7 @@ def fold_inputs(
seq = prom[:, quant_level].to("cpu", dtype=torch.int64)
# NB: iterating a tensor yields 0-dim views, so the in-place += below shifts seq itself into this level's id range
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * quant_level )
# interleaved
else:
if ignore_index is not None:
seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
@ -96,7 +100,38 @@ def fold_inputs(
input_ids[i].append( sep )
offset = text_tokens + (audio_tokens * audio_rvq_levels)
for i, resp in enumerate(resp_list):
# deinterleaved
if quant_levels is not None:
# grab the previous rvq level
quant_level = quant_levels[i] - 1
if quant_level < 0:
seq = sep
else:
# responses may be kept as a list of per-level tensors, which can't be indexed with tuples, so handle both layouts
if isinstance(resp, list):
seq = resp[quant_level].to("cpu", dtype=torch.int64)
else:
seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * quant_level )
input_ids[i].append( seq )
input_ids[i].append( stop )
# interleaved
else:
seq = resp.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
input_ids[i].append( seq )
input_ids[i].append( stop )
for i, resp in enumerate(targ_list):
# deinterleaved
if quant_levels is not None:
quant_level = quant_levels[i]
seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
@ -104,8 +139,8 @@ def fold_inputs(
token += offset + ( audio_tokens * quant_level )
input_ids[i].append( seq )
if quant_level == 0:
input_ids[i].append( stop )
input_ids[i].append( stop )
# interleaved
else:
seq = resp.flatten().to("cpu", dtype=torch.int64)
for idx, token in enumerate( seq ):
@ -140,9 +175,18 @@ def unfold_outputs(
resp_list = [ [] for _ in range(batch_size) ]
for i, batch in enumerate( output_ids ):
# hacky logic to handle the prefix resp for rvq levels > 0
# a better way would be to observe whether the rvq level increased
should_flush = False
flushed = False
for idx, token in enumerate( batch ):
id = token.item()
if id == sep or id == stop:
if should_flush and quant_levels is not None and quant_levels[i] > 0:
resp_list[i] = []
should_flush = False
flushed = True
continue
if 0 <= id < text_tokens:
@ -151,6 +195,8 @@ def unfold_outputs(
prom_list[i].append( (id - text_tokens) % audio_tokens )
elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
resp_list[i].append( (id - text_tokens) % audio_tokens )
if not flushed:
should_flush = True
if quant_levels is not None:
prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64)
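
Below, a stripped-down sketch of the flush behavior unfold_outputs needs for levels > 0, assuming the decoded stream looks like `[level L-1 prefix] sep [generated level L] stop`: the first separator seen after any resp tokens discards the echoed prefix so only the newly generated level survives. `unfold_resp` is a hypothetical helper that treats every non-separator id as a resp token, which the real function does not.

```python
def unfold_resp(ids, sep, stop, quant_level):
    # hypothetical, simplified flush logic; the real code also filters id ranges
    resp, should_flush, flushed = [], False, False
    for id in ids:
        if id in (sep, stop):
            if should_flush and quant_level > 0:
                resp = []             # drop the echoed level L-1 prefix
                should_flush = False
                flushed = True
            continue
        resp.append(id)
        if not flushed:
            should_flush = True       # arm the flush once resp tokens appear
    return resp

# [prefix] sep [generated] stop -> only the generated tokens remain
assert unfold_resp([7, 8, 0, 5, 6, 1], sep=0, stop=1, quant_level=1) == [5, 6]
```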

View File

@ -179,6 +179,10 @@ class Model(LlmArchClass):
*args,
**kwargs,
):
if SELECTED_ARCH == "mamba":
if "attention_mask" in kwargs:
kwargs.pop("attention_mask")
output = super().forward(*args, **kwargs)
if SELECTED_ARCH in ["llama", "retnet"]:
@ -241,18 +245,48 @@ def example_usage():
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
]
proms_list = [
prom_list = [
qnt[:cfg.dataset.frames_per_second, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
resps_list = [
resp_list = [
qnt[:, :].to(device),
#qnt[cfg.dataset.frames_per_second:, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
text_list = text_list[:1]
proms_list = proms_list[:1]
resps_list = resps_list[:1]
prom_list = prom_list[:1]
resp_list = resp_list[:1]
# disabled manual step-through of the per-level fold/unfold round-trip, kept for debugging
if False:
output_list = [ [] ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[0])
unfolded = unfold_outputs( input_ids, quant_levels=[0])
print( 0, "inputs:", input_ids.shape, input_ids )
print( 0, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
output_list[0].append( resp_list[0][:, 0] )
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[1])
unfolded = unfold_outputs( input_ids, quant_levels=[1])
print( 1, "inputs:", input_ids.shape, input_ids )
print( 1, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
output_list[0].append( resp_list[0][:, 1] )
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[2])
unfolded = unfold_outputs( input_ids, quant_levels=[2])
print( 2, "inputs:", input_ids.shape, input_ids )
print( 2, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
output_list[0].append( resp_list[0][:, 2] )
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[3])
unfolded = unfold_outputs( input_ids, quant_levels=[3])
print( 3, "inputs:", input_ids.shape, input_ids )
print( 3, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
output_list[0].append( resp_list[0][:, 3] )
return
kwargs = {}
model = Model(**kwargs).to(device)
@ -328,7 +362,7 @@ def example_usage():
target_length = 0
resp_list = None
if cfg.model.interleave:
input_ids, attention_mask = fold_inputs(text_list, proms_list)
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
unfolded = unfold_outputs( output )
@ -337,17 +371,22 @@ def example_usage():
resp_list = [ [] for _ in range(len(text_list)) ]
for l in range(cfg.model.max_levels):
quant_levels = [ l ]
input_ids, attention_mask = fold_inputs(text_list, proms_list, quant_levels=quant_levels)
min_length = len(input_ids[0])
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True)
min_length = len(input_ids[0]) + 1
# print( "input:", l, input_ids.shape, input_ids )
output = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
min_length=min_length+(steps if l > 0 else 0),
max_length=min_length+steps,
eos_token_id=3 if l == 0 else None ,
min_length=min_length,
max_length=min_length+steps*2,
eos_token_id=3,
do_sample=False
)
# print( "output:", l, output.shape, output )
unfolded = unfold_outputs( output, quant_levels=quant_levels )
@ -355,8 +394,17 @@ def example_usage():
steps = 0
for batch, resp in enumerate(unfolded["resp_list"]):
length = resp.shape[-1]
print( "LEN:", resp.shape, steps )
# store length
if l == 0:
steps = max( steps, resp.shape[0] )
steps = max( steps, length )
# pad
else:
resp = resp[:steps]
if length < steps:
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
resp_list[batch].append( resp )
for i, resp in enumerate( resp_list ):
@ -375,15 +423,17 @@ def example_usage():
batch_size = len(text_list)
quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,))
input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list, quant_levels=quant_levels)
target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100, quant_levels=quant_levels)
if SELECTED_ARCH == "mamba":
stats |= engine.traverse(input_ids=input_ids, labels=target_ids)
if quant_levels is not None:
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
else:
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
stats |= {"grad_norm": engine.get_global_grad_norm()}
resps_list = [ resp for resp in resp_list ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resps_list, targ_list=resp_list, quant_levels=quant_levels)
target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
stats |= {"grad_norm": engine.get_global_grad_norm(), "quant_level": quant_levels[0].item()}
tqdm.write(f"{stats}")
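
To make the final hunk's intent explicit, a condensed restatement of the train step (same calls as above, nothing new): the batch is folded twice, once for the inputs (previous-level resp prefixes, empty at level 0) and once with `ignore_index=-100` so HF-style loss masking skips everything but the target tokens.

```python
# fold once for inputs: previous-level resp prefixes plus targets
input_ids, attention_mask = fold_inputs(
    text_list=text_list, prom_list=prom_list,
    resp_list=resps_list, targ_list=resp_list,
    quant_levels=quant_levels,
)
# fold again for labels: non-target positions become -100 and are
# ignored by the cross-entropy loss
target_ids, _ = fold_inputs(
    text_list=text_list, prom_list=prom_list,
    resp_list=resp_list, targ_list=resp_list,
    ignore_index=-100, quant_levels=quant_levels,
)
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
```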