diff --git a/vall_e/data.py b/vall_e/data.py index 468c378..0b2ab4a 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -106,6 +106,8 @@ def fold_inputs( if quant_levels is not None: # grab the previous rvq level quant_level = quant_levels[i] - 1 + # way to signal we want to inference for rvq level 0 + # without it, it's a random chance for any level to be selected again if quant_level < 0: seq = sep else: diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py index cedf3cc..e712184 100644 --- a/vall_e/models/experimental.py +++ b/vall_e/models/experimental.py @@ -370,13 +370,11 @@ def example_usage(): else: resp_list = [ [] for _ in range(len(text_list)) ] for l in range(cfg.model.max_levels): - quant_levels = [ l ] + quant_levels = [ [ l ] for _ in range(len(text_list)) ] input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True) min_length = len(input_ids[0]) + 1 - # print( "input:", l, input_ids.shape, input_ids ) - output = model.generate( input_ids=input_ids, attention_mask=attention_mask, @@ -385,8 +383,6 @@ def example_usage(): eos_token_id=3, do_sample=False ) - - # print( "output:", l, output.shape, output ) unfolded = unfold_outputs( output, quant_levels=quant_levels ) @@ -395,7 +391,6 @@ def example_usage(): for batch, resp in enumerate(unfolded["resp_list"]): length = resp.shape[-1] - print( "LEN:", resp.shape, steps ) # store length if l == 0: @@ -433,7 +428,7 @@ def example_usage(): target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels) stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask) - stats |= {"grad_norm": engine.get_global_grad_norm(), "quant_level": quant_levels[0].item()} + stats |= {"grad_norm": engine.get_global_grad_norm()} tqdm.write(f"{stats}") diff --git a/vall_e/train.py b/vall_e/train.py index d03d87d..0087933 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -27,20 +27,32 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu") def train_feeder(engine, batch): with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): if engine.hyper_config.experimental: + batch_size = len(batch["text"]) + if cfg.model.interleave: + quant_levels = None + resps_list = [ resp for resp in resp_list ] + else: + quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,)) + resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ] + input_ids, attention_mask = fold_inputs( text_list=batch["text"], prom_list=batch["proms"], - resp_list=batch["resps"], + resp_list=resps_list, + targ_list=batch["resps"], + quant_levels=quant_levels, ) target_ids, target_attention_mask = fold_inputs( text_list=batch["text"], prom_list=batch["proms"], - resp_list=batch["resps"], + resp_list=resps_list, + targ_list=batch["resps"], + quant_levels=quant_levels, ignore_index=-100 ) engine( input_ids=input_ids, - labels=target_ids + labels=target_ids, ) else: engine( @@ -107,12 +119,54 @@ def run_eval(engines, eval_name, dl): engine = engines[name] if engine.hyper_config.experimental: - input_ids, attention_mask = fold_inputs( - text_list=batch["text"], - prom_list=batch["proms"], - ) - output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False) - resps_list = unfold_outputs( output )["resp_list"] + if cfg.model.interleave: + input_ids, attention_mask = fold_inputs( + text_list=batch["text"], + prom_list=batch["proms"], + ) + output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False) + resps_list = unfold_outputs( output )["resp_list"] + else: + steps = cfg.evaluation.steps + resps_list = [ [] for _ in range(len(text_list)) ] + for l in range(cfg.model.max_levels): + quant_levels = [ [ l ] for _ in range(len(text_list)) ] + + input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True) + min_length = 1 + for batch in input_ids: + min_length = max( min_length, batch.shape[0] ) + + output = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + min_length=min_length, + max_length=min_length+steps*(2 if l > 0 else 1), + eos_token_id=3, + do_sample=False + ) + + unfolded = unfold_outputs( output, quant_levels=quant_levels ) + + if l == 0: + steps = 0 + + for batch, resp in enumerate(unfolded["resp_list"]): + length = resp.shape[-1] + + # store length + if l == 0: + steps = max( steps, length ) + # pad + else: + resp = resp[:steps] + if length < steps: + resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ]) + + resps_list[batch].append( resp ) + + for i, resp in enumerate( resps_list ): + resps_list[i] = torch.stack( resp ).t() else: resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) resps_list = [ r.unsqueeze(-1) for r in resps_list ]