forgot one crucial detail (you *need* the previous RVQ level to keep coherence between all RVQ levels) (experimental deinterleaved is a bit crusty though)
parent 2ffad5cb6f
commit 0aa01ba31a
@@ -36,6 +36,7 @@ def fold_inputs(
     text_list = [],
     prom_list = [],
     resp_list = [],
+    targ_list = [],
 
     ignore_index = None,
 
@@ -46,6 +47,7 @@ def fold_inputs(
     audio_tokens = 1024,
     audio_rvq_levels = cfg.model.max_levels,
     quant_levels = None,
+    experimental = False
 ):
     def _create_mask(l, device):
         seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
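Two additions to the signature worth flagging: `targ_list` carries the codes actually being predicted (the training fold further down passes it alongside the fed-back `resp_list`), and `experimental` is a flag the deinterleaved inference path below passes as `True`; its effect inside `fold_inputs` isn't visible in these hunks.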
@@ -76,6 +78,7 @@ def fold_inputs(
 
     offset = text_tokens
     for i, prom in enumerate(prom_list):
+        # deinterleaved
         if quant_levels is not None:
             quant_level = quant_levels[i]
             if ignore_index is not None:
@@ -84,6 +87,7 @@ def fold_inputs(
                 seq = prom[:, quant_level].to("cpu", dtype=torch.int64)
                 for idx, token in enumerate( seq ):
                     token += offset + ( audio_tokens * quant_level )
+        # interleaved
         else:
             if ignore_index is not None:
                 seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
@@ -96,7 +100,38 @@ def fold_inputs(
         input_ids[i].append( sep )
 
     offset = text_tokens + (audio_tokens * audio_rvq_levels)
 
     for i, resp in enumerate(resp_list):
+        # deinterleaved
+        if quant_levels is not None:
+            # grab the previous rvq level
+            quant_level = quant_levels[i] - 1
+            if quant_level < 0:
+                seq = sep
+            else:
+                # my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples
+                if isinstance(resp, list):
+                    seq = resp[quant_level].to("cpu", dtype=torch.int64)
+                else:
+                    seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
+
+                for idx, token in enumerate( seq ):
+                    token += offset + ( audio_tokens * quant_level )
+
+            input_ids[i].append( seq )
+            input_ids[i].append( stop )
+        # interleaved
+        else:
             seq = resp.flatten().to("cpu", dtype=torch.int64)
             for idx, token in enumerate( seq ):
                 token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
 
             input_ids[i].append( seq )
             input_ids[i].append( stop )
+
+    for i, resp in enumerate(targ_list):
+        # deinterleaved
+        if quant_levels is not None:
+            quant_level = quant_levels[i]
+            seq = resp[:, quant_level].to("cpu", dtype=torch.int64)
 
@@ -104,8 +139,8 @@ def fold_inputs(
                 token += offset + ( audio_tokens * quant_level )
 
             input_ids[i].append( seq )
-            input_ids[i].append( stop )
+            if quant_level == 0:
+                input_ids[i].append( stop )
+        # interleaved
         else:
             seq = resp.flatten().to("cpu", dtype=torch.int64)
             for idx, token in enumerate( seq ):
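Context for both fold hunks: the scheme packs text, prompt, and response into one flat id space. Text ids occupy [0, text_tokens); prompt codes get one band of audio_tokens per RVQ level starting at text_tokens; response codes get the same per-level banding shifted past the entire prompt region. Unfolding subtracts text_tokens and reduces modulo audio_tokens, which recovers the raw code regardless of band. A minimal standalone sketch of that mapping (sizes made up; the helper names are illustrative, not functions from this repo):

```python
# Illustrative sketch of the id layout fold_inputs relies on (made-up sizes).
text_tokens = 256        # ids [0, text_tokens) are text
audio_tokens = 1024      # codes per RVQ level
audio_rvq_levels = 4     # number of RVQ levels

def fold_prom(code, quant_level):
    # prompt region: one band of audio_tokens per RVQ level, after the text ids
    return text_tokens + audio_tokens * quant_level + code

def fold_resp(code, quant_level):
    # response region: same banding, shifted past the whole prompt region
    offset = text_tokens + audio_tokens * audio_rvq_levels
    return offset + audio_tokens * quant_level + code

def unfold(id):
    # inverse used by unfold_outputs: the band is discarded, the raw code recovered
    return (id - text_tokens) % audio_tokens

assert unfold(fold_prom(123, 2)) == 123
assert unfold(fold_resp(123, 3)) == 123
```

The interleaved path uses the same banding, except the band index cycles with position (idx % audio_rvq_levels) instead of being fixed by quant_level.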
@@ -140,9 +175,18 @@ def unfold_outputs(
     resp_list = [ [] for _ in range(batch_size) ]
 
     for i, batch in enumerate( output_ids ):
+        # crigne logic to handle prefix resp for rvq levels > 0
+        # a better way is to observe if the rvq level increased
+        should_flush = False
+        flushed = False
         for idx, token in enumerate( batch ):
             id = token.item()
             if id == sep or id == stop:
+                if should_flush and quant_levels is not None and quant_levels[i] > 0:
+                    resp_list[i] = []
+                    should_flush = False
+                    flushed = True
+
                 continue
 
             if 0 <= id and id < text_tokens:
@@ -151,6 +195,8 @@ def unfold_outputs(
                 prom_list[i].append( (id - text_tokens) % audio_tokens )
             elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
                 resp_list[i].append( (id - text_tokens) % audio_tokens )
+                if not flushed:
+                    should_flush = True
 
         if quant_levels is not None:
             prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64)
 
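The flush bookkeeping above is the subtle part: when sampling RVQ level l > 0, the previous level's response is fed back in as a prefix, so its tokens land in the response id range of the output too. The first sep/stop encountered after response tokens therefore clears resp_list[i] once, leaving only the freshly generated level. A standalone sketch of that logic (extract_resp and its predicate arguments are illustrative, not this repo's API):

```python
# Simplified sketch of the prefix-flush logic in unfold_outputs.
def extract_resp(tokens, quant_level, is_resp_id, is_boundary_id):
    resp, should_flush, flushed = [], False, False
    for id in tokens:
        if is_boundary_id(id):                  # sep or stop
            if should_flush and quant_level > 0:
                resp = []                       # drop the fed-back previous-level prefix
                should_flush, flushed = False, True
            continue
        if is_resp_id(id):
            resp.append(id)
            if not flushed:
                should_flush = True             # a prefix may still need flushing
    return resp

# at level 1, [prefix..., boundary, new...] keeps only the new tokens:
assert extract_resp([10, 11, 0, 20, 21], 1, lambda i: i >= 10, lambda i: i == 0) == [20, 21]
```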
@@ -179,6 +179,10 @@ class Model(LlmArchClass):
         *args,
         **kwargs,
     ):
+        if SELECTED_ARCH == "mamba":
+            if "attention_mask" in kwargs:
+                kwargs.pop("attention_mask")
+
         output = super().forward(*args, **kwargs)
 
         if SELECTED_ARCH in ["llama", "retnet"]:
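The attention_mask pop for mamba is presumably needed because the mamba forward signature doesn't accept one, so it has to be stripped before delegating to super().forward(); the llama/retnet paths keep the mask.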
@@ -241,18 +245,48 @@ def example_usage():
         tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
         #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
     ]
-    proms_list = [
+    prom_list = [
         qnt[:cfg.dataset.frames_per_second, :].to(device),
         #qnt[:cfg.dataset.frames_per_second, :].to(device),
     ]
-    resps_list = [
+    resp_list = [
         qnt[:, :].to(device),
         #qnt[cfg.dataset.frames_per_second:, :].to(device),
         #qnt[:cfg.dataset.frames_per_second, :].to(device),
     ]
 
     text_list = text_list[:1]
-    proms_list = proms_list[:1]
-    resps_list = resps_list[:1]
+    prom_list = prom_list[:1]
+    resp_list = resp_list[:1]
+
+    if False:
+        output_list = [ [] ]
+
+        input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[0])
+        unfolded = unfold_outputs( input_ids, quant_levels=[0])
+        print( 0, "inputs:", input_ids.shape, input_ids )
+        print( 0, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
+        output_list[0].append( resp_list[0][:, 0] )
+
+        input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[1])
+        unfolded = unfold_outputs( input_ids, quant_levels=[1])
+        print( 1, "inputs:", input_ids.shape, input_ids )
+        print( 1, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
+        output_list[0].append( resp_list[0][:, 1] )
+
+        input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[2])
+        unfolded = unfold_outputs( input_ids, quant_levels=[2])
+        print( 2, "inputs:", input_ids.shape, input_ids )
+        print( 2, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
+        output_list[0].append( resp_list[0][:, 2] )
+
+        input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[3])
+        unfolded = unfold_outputs( input_ids, quant_levels=[3])
+        print( 3, "inputs:", input_ids.shape, input_ids )
+        print( 3, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
+        output_list[0].append( resp_list[0][:, 3] )
+
+        return
 
     kwargs = {}
     model = Model(**kwargs).to(device)
@@ -328,7 +362,7 @@ def example_usage():
         target_length = 0
         resp_list = None
         if cfg.model.interleave:
-            input_ids, attention_mask = fold_inputs(text_list, proms_list)
+            input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
             output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
 
             unfolded = unfold_outputs( output )
@@ -337,17 +371,22 @@ def example_usage():
             resp_list = [ [] for _ in range(len(text_list)) ]
             for l in range(cfg.model.max_levels):
                 quant_levels = [ l ]
-                input_ids, attention_mask = fold_inputs(text_list, proms_list, quant_levels=quant_levels)
-                min_length = len(input_ids[0])
+                input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True)
+                min_length = len(input_ids[0]) + 1
+
+                # print( "input:", l, input_ids.shape, input_ids )
 
                 output = model.generate(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
-                    min_length=min_length+(steps if l > 0 else 0),
-                    max_length=min_length+steps,
-                    eos_token_id=3 if l == 0 else None ,
+                    min_length=min_length,
+                    max_length=min_length+steps*2,
+                    eos_token_id=3,
                     do_sample=False
                 )
 
+                # print( "output:", l, output.shape, output )
+
                 unfolded = unfold_outputs( output, quant_levels=quant_levels )
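So the per-level decode now works as the commit message says: level 0 is generated until eos as usual, and every later level l re-generates the sequence with level l-1 folded back in as the response prefix. A condensed sketch of the loop (assumes the fold_inputs/unfold_outputs semantics above and an HF-style generate; the diff's min/max_length and padding bookkeeping are omitted):

```python
# Condensed, illustrative sketch of level-by-level decoding.
def decode_all_levels(model, text_list, prom_list, max_levels):
    resp_list = [ [] for _ in text_list ]   # per-batch list of per-level codes
    for l in range(max_levels):
        # fold text + prompt + the levels decoded so far; with quant_levels=[l],
        # fold_inputs picks level l-1 out of resp_list as the prefix
        input_ids, attention_mask = fold_inputs(
            text_list=text_list, prom_list=prom_list,
            resp_list=resp_list, quant_levels=[l], experimental=True,
        )
        output = model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                eos_token_id=3, do_sample=False)
        # unfold flushes the fed-back prefix, keeping only the new level's codes
        for b, codes in enumerate(unfold_outputs(output, quant_levels=[l])["resp_list"]):
            resp_list[b].append(codes)
    return resp_list
```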
@@ -355,8 +394,17 @@ def example_usage():
                 steps = 0
 
                 for batch, resp in enumerate(unfolded["resp_list"]):
+                    length = resp.shape[-1]
+                    print( "LEN:", resp.shape, steps )
+
+                    # store length
                     if l == 0:
-                        steps = max( steps, resp.shape[0] )
+                        steps = max( steps, length )
+                    # pad
                     else:
                         resp = resp[:steps]
+                        if length < steps:
+                            resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
                     resp_list[batch].append( resp )
 
             for i, resp in enumerate( resp_list ):
@@ -375,15 +423,17 @@ def example_usage():
 
             batch_size = len(text_list)
             quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,))
 
-            input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list, quant_levels=quant_levels)
-            target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100, quant_levels=quant_levels)
-
-            if SELECTED_ARCH == "mamba":
-                stats |= engine.traverse(input_ids=input_ids, labels=target_ids)
-            else:
-                stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
-            stats |= {"grad_norm": engine.get_global_grad_norm()}
+            if quant_levels is not None:
+                resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
+            else:
+                resps_list = [ resp for resp in resp_list ]
+
+            input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resps_list, targ_list=resp_list, quant_levels=quant_levels)
+            target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
+
+            stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
+            stats |= {"grad_norm": engine.get_global_grad_norm(), "quant_level": quant_levels[0].item()}
 
             tqdm.write(f"{stats}")
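Training side, condensed: a random quant level is drawn per batch item; the response fed to the model is empty at level 0 and the full code tensor otherwise (fold_inputs selects level l-1 from it), while targ_list always carries the codes being predicted, and the label fold uses ignore_index=-100 to mask what shouldn't be scored. A sketch of one step (assumes fold_inputs as above and an engine whose traverse() runs a single optimization step, as in the diff):

```python
import torch

# Illustrative sketch of one training step, simplified from the diff.
def train_step(engine, text_list, prom_list, resp_list, max_levels, interleave=False):
    batch_size = len(text_list)
    quant_levels = None if interleave else torch.randint(0, max_levels, (batch_size,))

    if quant_levels is not None:
        # model input: nothing at level 0, otherwise the full code tensor
        # (fold_inputs picks level l-1 out of it)
        resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
    else:
        resps_list = [ resp for resp in resp_list ]

    input_ids, attention_mask = fold_inputs(
        text_list=text_list, prom_list=prom_list,
        resp_list=resps_list, targ_list=resp_list, quant_levels=quant_levels)
    # labels: same fold, with ignore_index=-100 masking positions not being scored
    target_ids, _ = fold_inputs(
        text_list=text_list, prom_list=prom_list,
        resp_list=resp_list, targ_list=resp_list,
        ignore_index=-100, quant_levels=quant_levels)

    return engine.traverse(input_ids=input_ids, labels=target_ids,
                           attention_mask=attention_mask)
```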