copy pasted from test to actual trainer

This commit is contained in:
mrq 2024-06-04 18:40:30 -05:00
parent 0aa01ba31a
commit ed3aeaf3a1
3 changed files with 67 additions and 16 deletions

View File

@ -106,6 +106,8 @@ def fold_inputs(
if quant_levels is not None: if quant_levels is not None:
# grab the previous rvq level # grab the previous rvq level
quant_level = quant_levels[i] - 1 quant_level = quant_levels[i] - 1
# way to signal we want to inference for rvq level 0
# without it, it's a random chance for any level to be selected again
if quant_level < 0: if quant_level < 0:
seq = sep seq = sep
else: else:

View File

@ -370,13 +370,11 @@ def example_usage():
else: else:
resp_list = [ [] for _ in range(len(text_list)) ] resp_list = [ [] for _ in range(len(text_list)) ]
for l in range(cfg.model.max_levels): for l in range(cfg.model.max_levels):
quant_levels = [ l ] quant_levels = [ [ l ] for _ in range(len(text_list)) ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True) input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True)
min_length = len(input_ids[0]) + 1 min_length = len(input_ids[0]) + 1
# print( "input:", l, input_ids.shape, input_ids )
output = model.generate( output = model.generate(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
@ -385,8 +383,6 @@ def example_usage():
eos_token_id=3, eos_token_id=3,
do_sample=False do_sample=False
) )
# print( "output:", l, output.shape, output )
unfolded = unfold_outputs( output, quant_levels=quant_levels ) unfolded = unfold_outputs( output, quant_levels=quant_levels )
@ -395,7 +391,6 @@ def example_usage():
for batch, resp in enumerate(unfolded["resp_list"]): for batch, resp in enumerate(unfolded["resp_list"]):
length = resp.shape[-1] length = resp.shape[-1]
print( "LEN:", resp.shape, steps )
# store length # store length
if l == 0: if l == 0:
@ -433,7 +428,7 @@ def example_usage():
target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels) target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask) stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
stats |= {"grad_norm": engine.get_global_grad_norm(), "quant_level": quant_levels[0].item()} stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}") tqdm.write(f"{stats}")

View File

@ -27,20 +27,32 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch): def train_feeder(engine, batch):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
if engine.hyper_config.experimental: if engine.hyper_config.experimental:
batch_size = len(batch["text"])
if cfg.model.interleave:
quant_levels = None
resps_list = [ resp for resp in resp_list ]
else:
quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
input_ids, attention_mask = fold_inputs( input_ids, attention_mask = fold_inputs(
text_list=batch["text"], text_list=batch["text"],
prom_list=batch["proms"], prom_list=batch["proms"],
resp_list=batch["resps"], resp_list=resps_list,
targ_list=batch["resps"],
quant_levels=quant_levels,
) )
target_ids, target_attention_mask = fold_inputs( target_ids, target_attention_mask = fold_inputs(
text_list=batch["text"], text_list=batch["text"],
prom_list=batch["proms"], prom_list=batch["proms"],
resp_list=batch["resps"], resp_list=resps_list,
targ_list=batch["resps"],
quant_levels=quant_levels,
ignore_index=-100 ignore_index=-100
) )
engine( engine(
input_ids=input_ids, input_ids=input_ids,
labels=target_ids labels=target_ids,
) )
else: else:
engine( engine(
@ -107,12 +119,54 @@ def run_eval(engines, eval_name, dl):
engine = engines[name] engine = engines[name]
if engine.hyper_config.experimental: if engine.hyper_config.experimental:
input_ids, attention_mask = fold_inputs( if cfg.model.interleave:
text_list=batch["text"], input_ids, attention_mask = fold_inputs(
prom_list=batch["proms"], text_list=batch["text"],
) prom_list=batch["proms"],
output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False) )
resps_list = unfold_outputs( output )["resp_list"] output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
resps_list = unfold_outputs( output )["resp_list"]
else:
steps = cfg.evaluation.steps
resps_list = [ [] for _ in range(len(text_list)) ]
for l in range(cfg.model.max_levels):
quant_levels = [ [ l ] for _ in range(len(text_list)) ]
input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True)
min_length = 1
for batch in input_ids:
min_length = max( min_length, batch.shape[0] )
output = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
min_length=min_length,
max_length=min_length+steps*(2 if l > 0 else 1),
eos_token_id=3,
do_sample=False
)
unfolded = unfold_outputs( output, quant_levels=quant_levels )
if l == 0:
steps = 0
for batch, resp in enumerate(unfolded["resp_list"]):
length = resp.shape[-1]
# store length
if l == 0:
steps = max( steps, length )
# pad
else:
resp = resp[:steps]
if length < steps:
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
resps_list[batch].append( resp )
for i, resp in enumerate( resps_list ):
resps_list[i] = torch.stack( resp ).t()
else: else:
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = [ r.unsqueeze(-1) for r in resps_list ] resps_list = [ r.unsqueeze(-1) for r in resps_list ]