copy pasted from test to actual trainer

commit ed3aeaf3a1
parent 0aa01ba31a
@@ -106,6 +106,8 @@ def fold_inputs(
 		if quant_levels is not None:
 			# grab the previous rvq level
 			quant_level = quant_levels[i] - 1
+			# way to signal we want to inference for rvq level 0
+			# without it, it's a random chance for any level to be selected again
 			if quant_level < 0:
 				seq = sep
 			else:
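The added comments pin down a sentinel convention in fold_inputs: the function receives the target RVQ level per sample, steps back one level to find the codes to fold into the input, and a negative result means "inference RVQ level 0", so only the separator is emitted. A minimal sketch of that convention (function and argument names are assumptions, not the repo's code):

def resp_segment(resp, target_level, sep):
	# the model is conditioned on the codes one level *below* the target
	quant_level = target_level - 1
	if quant_level < 0:
		# sentinel: inferencing rvq level 0; there are no prior codes to fold,
		# and without it any level could be selected again at random
		return sep
	# assumption: response codes are indexed [timestep, level]
	return resp[:, quant_level]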
@@ -370,13 +370,11 @@ def example_usage():
 	else:
 		resp_list = [ [] for _ in range(len(text_list)) ]
 		for l in range(cfg.model.max_levels):
-			quant_levels = [ l ]
+			quant_levels = [ [ l ] for _ in range(len(text_list)) ]
 
 			input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True)
 			min_length = len(input_ids[0]) + 1
 
-			# print( "input:", l, input_ids.shape, input_ids )
-
 			output = model.generate(
 				input_ids=input_ids,
 				attention_mask=attention_mask,
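The functional change here is the shape of quant_levels: a single [ l ] shared across the batch becomes one [ l ] entry per sample. A toy illustration (batch contents assumed):

text_list = [ "a", "b", "c" ]   # hypothetical three-sample batch
l = 2                           # current rvq level in the loop

quant_levels_old = [ l ]                                     # one entry total
quant_levels_new = [ [ l ] for _ in range(len(text_list)) ]  # [[2], [2], [2]]
# fold_inputs indexes quant_levels per sample (quant_levels[i]), so the
# batched form gives every sample its own level entry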
@@ -386,8 +384,6 @@ def example_usage():
 				do_sample=False
 			)
 
-			# print( "output:", l, output.shape, output )
-
 			unfolded = unfold_outputs( output, quant_levels=quant_levels )
 
 			if l == 0:
@@ -395,7 +391,6 @@ def example_usage():
 
 			for batch, resp in enumerate(unfolded["resp_list"]):
 				length = resp.shape[-1]
-				print( "LEN:", resp.shape, steps )
 
 				# store length
 				if l == 0:
@@ -433,7 +428,7 @@ def example_usage():
 		target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
 
 		stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
-		stats |= {"grad_norm": engine.get_global_grad_norm(), "quant_level": quant_levels[0].item()}
+		stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 		tqdm.write(f"{stats}")
 
@@ -27,20 +27,32 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
 def train_feeder(engine, batch):
 	with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
 		if engine.hyper_config.experimental:
+			batch_size = len(batch["text"])
+			if cfg.model.interleave:
+				quant_levels = None
+				resps_list = [ resp for resp in resp_list ]
+			else:
+				quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
+				resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
+
 			input_ids, attention_mask = fold_inputs(
 				text_list=batch["text"],
 				prom_list=batch["proms"],
-				resp_list=batch["resps"],
+				resp_list=resps_list,
+				targ_list=batch["resps"],
+				quant_levels=quant_levels,
 			)
 			target_ids, target_attention_mask = fold_inputs(
 				text_list=batch["text"],
 				prom_list=batch["proms"],
-				resp_list=batch["resps"],
+				resp_list=resps_list,
+				targ_list=batch["resps"],
+				quant_levels=quant_levels,
 				ignore_index=-100
 			)
 			engine(
 				input_ids=input_ids,
-				labels=target_ids
+				labels=target_ids,
 			)
 		else:
 			engine(
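This is the logic copied over from the test: for non-interleaved models, each sample in the batch is assigned a random target RVQ level, and samples targeting level 0 get an empty response input (the sentinel case documented in the first hunk). A self-contained sketch of the sampling; resp_list stands in for the batch's response codes (batch["resps"] in the trainer), and shapes and vocab size are assumptions:

import torch

max_levels = 8
resp_list = [ torch.randint(0, 1024, (t, max_levels)) for t in (100, 120) ]

batch_size = len(resp_list)
quant_levels = torch.randint(0, max_levels, (batch_size,))  # one target level per sample

# level 0 has no previous codes to condition on, so its response input is left
# empty; higher levels pass their codes through for fold_inputs to slice
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]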
@@ -107,12 +119,54 @@ def run_eval(engines, eval_name, dl):
 		engine = engines[name]
 
 		if engine.hyper_config.experimental:
-			input_ids, attention_mask = fold_inputs(
-				text_list=batch["text"],
-				prom_list=batch["proms"],
-			)
-			output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
-			resps_list = unfold_outputs( output )["resp_list"]
+			if cfg.model.interleave:
+				input_ids, attention_mask = fold_inputs(
+					text_list=batch["text"],
+					prom_list=batch["proms"],
+				)
+				output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
+				resps_list = unfold_outputs( output )["resp_list"]
+			else:
+				steps = cfg.evaluation.steps
+				resps_list = [ [] for _ in range(len(text_list)) ]
+				for l in range(cfg.model.max_levels):
+					quant_levels = [ [ l ] for _ in range(len(text_list)) ]
+
+					input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True)
+					min_length = 1
+					for batch in input_ids:
+						min_length = max( min_length, batch.shape[0] )
+
+					output = model.generate(
+						input_ids=input_ids,
+						attention_mask=attention_mask,
+						min_length=min_length,
+						max_length=min_length+steps*(2 if l > 0 else 1),
+						eos_token_id=3,
+						do_sample=False
+					)
+
+					unfolded = unfold_outputs( output, quant_levels=quant_levels )
+
+					if l == 0:
+						steps = 0
+
+					for batch, resp in enumerate(unfolded["resp_list"]):
+						length = resp.shape[-1]
+
+						# store length
+						if l == 0:
+							steps = max( steps, length )
+						# pad
+						else:
+							resp = resp[:steps]
+							if length < steps:
+								resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
+
+						resps_list[batch].append( resp )
+
+				for i, resp in enumerate( resps_list ):
+					resps_list[i] = torch.stack( resp ).t()
 		else:
 			resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
 			resps_list = [ r.unsqueeze(-1) for r in resps_list ]
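In this non-interleaved eval path, level 0 is generated first and its length becomes the frame count ("steps") for the utterance; every later level is truncated or zero-padded to match before the per-level results are stacked into a [timesteps, levels] tensor. A self-contained sketch of that alignment rule (helper name and shapes are mine, not the repo's):

import torch

def align_to_steps(resp: torch.Tensor, steps: int) -> torch.Tensor:
	resp = resp[:steps]                       # truncate overly long levels
	if resp.shape[-1] < steps:                # zero-pad short ones
		pad = torch.zeros(steps - resp.shape[-1], dtype=resp.dtype, device=resp.device)
		resp = torch.cat([ resp, pad ])
	return resp

level0 = torch.randint(0, 1024, (75,))       # sets steps = 75
level1 = torch.randint(0, 1024, (80,))       # gets cut to 75
level2 = torch.randint(0, 1024, (60,))       # gets padded to 75
stacked = torch.stack([ level0 ] + [ align_to_steps(r, 75) for r in (level1, level2) ]).t()
assert stacked.shape == (75, 3)              # (frames, rvq levels), matching the .t() above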