copy pasted from test to actual trainer
This commit is contained in:
parent
0aa01ba31a
commit
ed3aeaf3a1
|
@ -106,6 +106,8 @@ def fold_inputs(
|
|||
if quant_levels is not None:
|
||||
# grab the previous rvq level
|
||||
quant_level = quant_levels[i] - 1
|
||||
# way to signal we want to inference for rvq level 0
|
||||
# without it, it's a random chance for any level to be selected again
|
||||
if quant_level < 0:
|
||||
seq = sep
|
||||
else:
|
||||
|
|
|
@ -370,13 +370,11 @@ def example_usage():
|
|||
else:
|
||||
resp_list = [ [] for _ in range(len(text_list)) ]
|
||||
for l in range(cfg.model.max_levels):
|
||||
quant_levels = [ l ]
|
||||
quant_levels = [ [ l ] for _ in range(len(text_list)) ]
|
||||
|
||||
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True)
|
||||
min_length = len(input_ids[0]) + 1
|
||||
|
||||
# print( "input:", l, input_ids.shape, input_ids )
|
||||
|
||||
output = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
@ -386,8 +384,6 @@ def example_usage():
|
|||
do_sample=False
|
||||
)
|
||||
|
||||
# print( "output:", l, output.shape, output )
|
||||
|
||||
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
||||
|
||||
if l == 0:
|
||||
|
@ -395,7 +391,6 @@ def example_usage():
|
|||
|
||||
for batch, resp in enumerate(unfolded["resp_list"]):
|
||||
length = resp.shape[-1]
|
||||
print( "LEN:", resp.shape, steps )
|
||||
|
||||
# store length
|
||||
if l == 0:
|
||||
|
@ -433,7 +428,7 @@ def example_usage():
|
|||
target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
|
||||
|
||||
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
|
||||
stats |= {"grad_norm": engine.get_global_grad_norm(), "quant_level": quant_levels[0].item()}
|
||||
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
||||
|
||||
tqdm.write(f"{stats}")
|
||||
|
||||
|
|
|
@ -27,20 +27,32 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
|
|||
def train_feeder(engine, batch):
|
||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||
if engine.hyper_config.experimental:
|
||||
batch_size = len(batch["text"])
|
||||
if cfg.model.interleave:
|
||||
quant_levels = None
|
||||
resps_list = [ resp for resp in resp_list ]
|
||||
else:
|
||||
quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
|
||||
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
|
||||
|
||||
input_ids, attention_mask = fold_inputs(
|
||||
text_list=batch["text"],
|
||||
prom_list=batch["proms"],
|
||||
resp_list=batch["resps"],
|
||||
resp_list=resps_list,
|
||||
targ_list=batch["resps"],
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
target_ids, target_attention_mask = fold_inputs(
|
||||
text_list=batch["text"],
|
||||
prom_list=batch["proms"],
|
||||
resp_list=batch["resps"],
|
||||
resp_list=resps_list,
|
||||
targ_list=batch["resps"],
|
||||
quant_levels=quant_levels,
|
||||
ignore_index=-100
|
||||
)
|
||||
engine(
|
||||
input_ids=input_ids,
|
||||
labels=target_ids
|
||||
labels=target_ids,
|
||||
)
|
||||
else:
|
||||
engine(
|
||||
|
@ -107,12 +119,54 @@ def run_eval(engines, eval_name, dl):
|
|||
engine = engines[name]
|
||||
|
||||
if engine.hyper_config.experimental:
|
||||
input_ids, attention_mask = fold_inputs(
|
||||
text_list=batch["text"],
|
||||
prom_list=batch["proms"],
|
||||
)
|
||||
output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
|
||||
resps_list = unfold_outputs( output )["resp_list"]
|
||||
if cfg.model.interleave:
|
||||
input_ids, attention_mask = fold_inputs(
|
||||
text_list=batch["text"],
|
||||
prom_list=batch["proms"],
|
||||
)
|
||||
output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
|
||||
resps_list = unfold_outputs( output )["resp_list"]
|
||||
else:
|
||||
steps = cfg.evaluation.steps
|
||||
resps_list = [ [] for _ in range(len(text_list)) ]
|
||||
for l in range(cfg.model.max_levels):
|
||||
quant_levels = [ [ l ] for _ in range(len(text_list)) ]
|
||||
|
||||
input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True)
|
||||
min_length = 1
|
||||
for batch in input_ids:
|
||||
min_length = max( min_length, batch.shape[0] )
|
||||
|
||||
output = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
min_length=min_length,
|
||||
max_length=min_length+steps*(2 if l > 0 else 1),
|
||||
eos_token_id=3,
|
||||
do_sample=False
|
||||
)
|
||||
|
||||
unfolded = unfold_outputs( output, quant_levels=quant_levels )
|
||||
|
||||
if l == 0:
|
||||
steps = 0
|
||||
|
||||
for batch, resp in enumerate(unfolded["resp_list"]):
|
||||
length = resp.shape[-1]
|
||||
|
||||
# store length
|
||||
if l == 0:
|
||||
steps = max( steps, length )
|
||||
# pad
|
||||
else:
|
||||
resp = resp[:steps]
|
||||
if length < steps:
|
||||
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
|
||||
|
||||
resps_list[batch].append( resp )
|
||||
|
||||
for i, resp in enumerate( resps_list ):
|
||||
resps_list[i] = torch.stack( resp ).t()
|
||||
else:
|
||||
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
|
||||
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
|
||||
|
|
Loading…
Reference in New Issue
Block a user