From 406ff7bbe129d54b24f4df2542797e543556e43c Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 4 Jun 2024 14:19:52 -0500 Subject: [PATCH] re-implemented config.model.interleave for the HF-compat experimental method --- vall_e/config.py | 2 +- vall_e/data.py | 94 ++++++++++++++++++++++------------- vall_e/models/experimental.py | 71 ++++++++++++++++++++++---- 3 files changed, 121 insertions(+), 46 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 8f2d715..5d835c4 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -216,7 +216,7 @@ class Model: dropout: float = 0.1 # adjustable dropout value loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) kv_heads: int = 0 - experimental: bool = False + experimental: bool = False # for now it sets things to be HF compatible def get(self, name=None): return [ self ] if not name or self.name == name else [] diff --git a/vall_e/data.py b/vall_e/data.py index aa6e1af..3f65c28 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -44,7 +44,8 @@ def fold_inputs( text_tokens = 256, audio_tokens = 1024, - audio_rvq_levels = cfg.model.max_levels + audio_rvq_levels = cfg.model.max_levels, + quant_levels = None, ): def _create_mask(l, device): seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) @@ -75,23 +76,43 @@ def fold_inputs( offset = text_tokens for i, prom in enumerate(prom_list): - if ignore_index is not None: - seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64) + if quant_levels is not None: + quant_level = quant_levels[i] + if ignore_index is not None: + seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to("cpu", dtype=torch.int64) + else: + seq = prom[:, quant_level].to("cpu", dtype=torch.int64) + for idx, token in enumerate( seq ): + token += offset + ( audio_tokens * quant_level ) else: - seq = prom.flatten().to("cpu", dtype=torch.int64) - for idx, token in enumerate( seq ): - token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) + if ignore_index is not None: + seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64) + else: + seq = prom.flatten().to("cpu", dtype=torch.int64) + for idx, token in enumerate( seq ): + token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) input_ids[i].append( seq ) input_ids[i].append( sep ) offset = text_tokens + (audio_tokens * audio_rvq_levels) for i, resp in enumerate(resp_list): - seq = resp.flatten().to("cpu", dtype=torch.int64) - for idx, token in enumerate( seq ): - token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) - input_ids[i].append( seq ) - input_ids[i].append( stop ) + if quant_levels is not None: + quant_level = quant_levels[i] + seq = resp[:, quant_level].to("cpu", dtype=torch.int64) + for idx, token in enumerate( seq ): + token += offset + ( audio_tokens * quant_level ) + + input_ids[i].append( seq ) + if quant_level == 0: + input_ids[i].append( stop ) + else: + seq = resp.flatten().to("cpu", dtype=torch.int64) + for idx, token in enumerate( seq ): + token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) + + input_ids[i].append( seq ) + input_ids[i].append( stop ) for i, batch in enumerate(input_ids): input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64) @@ -99,6 +120,7 @@ def fold_inputs( return list_to_tensor(input_ids) # unfold from one unified token ID space to separate token spaces +# to-do: unfold at a specific 
RVQ level instead if requested def unfold_outputs( output_ids, @@ -107,7 +129,8 @@ def unfold_outputs( text_tokens = 256, audio_tokens = 1024, - audio_rvq_levels = cfg.model.max_levels + audio_rvq_levels = cfg.model.max_levels, + quant_levels = None, ): device = output_ids.device batch_size = output_ids.shape[0] @@ -129,30 +152,33 @@ def unfold_outputs( elif text_tokens + (audio_tokens * audio_rvq_levels) <= id: resp_list[i].append( (id - text_tokens) % audio_tokens ) - prom_len = len(prom_list[i]) - if prom_len % audio_rvq_levels == 0 and False: - prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t() + if quant_levels is not None: + prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64) + resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=torch.int64) else: - bins = [ [] for _ in range(audio_rvq_levels) ] - for pos in range( prom_len ): - rvq = pos % audio_rvq_levels - bins[rvq].append( prom_list[i][pos] ) - nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels - bins = bins[:nearest] - prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64) + prom_len = len(prom_list[i]) + if prom_len % audio_rvq_levels == 0 and False: + prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t() + else: + bins = [ [] for _ in range(audio_rvq_levels) ] + for pos in range( prom_len ): + rvq = pos % audio_rvq_levels + bins[rvq].append( prom_list[i][pos] ) + nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels + bins = bins[:nearest] + prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64) - - resp_len = len(resp_list[i]) - if len(resp_list[i]) % audio_rvq_levels == 0 and False: - resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t() - else: - bins = [ [] for _ in range(audio_rvq_levels) ] - for pos in range( resp_len ): - rvq = pos % audio_rvq_levels - bins[rvq].append( resp_list[i][pos] ) - nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels - bins = bins[:nearest] - resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64) + resp_len = len(resp_list[i]) + if len(resp_list[i]) % audio_rvq_levels == 0 and False: + resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t() + else: + bins = [ [] for _ in range(audio_rvq_levels) ] + for pos in range( resp_len ): + rvq = pos % audio_rvq_levels + bins[rvq].append( resp_list[i][pos] ) + nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels + bins = bins[:nearest] + resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64) text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64) diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py index 98dc744..54ffacf 100644 --- a/vall_e/models/experimental.py +++ b/vall_e/models/experimental.py @@ -158,7 +158,22 @@ class Model(LlmArchClass): self.backbone.gradient_checkpointing = gradient_checkpointing + def generate( + self, + *args, + **kwargs + ): + if SELECTED_ARCH == "mamba": + kwargs["cg"] = True + if "attention_mask" in kwargs: + kwargs.pop("attention_mask") + + if "do_sample" in kwargs: + kwargs.pop("do_sample") + + return super().forward(*args, **kwargs) + def forward( self, *args, @@ -239,13 +254,9 @@ def example_usage(): proms_list = proms_list[:1] resps_list = resps_list[:1] - input_ids, attention_mask = fold_inputs(text_list, 
proms_list, resps_list) - target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100) - prefix_input_ids, prefix_attention_mask = fold_inputs(text_list, proms_list) - kwargs = {} model = Model(**kwargs).to(device) - steps = 50 + steps = 50 if cfg.model.interleave else 250 optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy" scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else "" @@ -312,15 +323,46 @@ def example_usage(): print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") @torch.inference_mode() - def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*60 ): + def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ): engine.eval() - if SELECTED_ARCH == "mamba": - output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3) + target_length = 0 + resp_list = None + if cfg.model.interleave: + input_ids, attention_mask = fold_inputs(text_list, proms_list) + output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False) + + unfolded = unfold_outputs( output ) + resp_list = unfolded["resp_list"] else: - output = model.generate(input_ids=prefix_input_ids, attention_mask=prefix_attention_mask, max_length=steps, eos_token_id=3, do_sample=False) + resp_list = [ [] for _ in range(len(text_list)) ] + for l in range(cfg.model.max_levels): + quant_levels = [ l ] + input_ids, attention_mask = fold_inputs(text_list, proms_list, quant_levels=quant_levels) + min_length = len(input_ids[0]) - unfolded = unfold_outputs( output ) - for i, batch in enumerate(unfolded["resp_list"]): + output = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + min_length=min_length+(steps if l > 0 else 0), + max_length=min_length+steps, + eos_token_id=3 if l == 0 else None , + do_sample=False + ) + + unfolded = unfold_outputs( output, quant_levels=quant_levels ) + + if l == 0: + steps = 0 + + for batch, resp in enumerate(unfolded["resp_list"]): + if l == 0: + steps = max( steps, resp.shape[0] ) + resp_list[batch].append( resp ) + + for i, resp in enumerate( resp_list ): + resp_list[i] = torch.stack( resp ).t() + + for i, batch in enumerate(resp_list): _ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device) unload_model() @@ -330,6 +372,13 @@ def example_usage(): t = trange(steps) for i in t: stats = {"step": i} + + batch_size = len(text_list) + quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,)) + + input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list, quant_levels=quant_levels) + target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100, quant_levels=quant_levels) + if SELECTED_ARCH == "mamba": stats |= engine.traverse(input_ids=input_ids, labels=target_ids) else:
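
For reference, the unified token-ID layout that the patched fold_inputs() builds can be sketched in isolation. The helper below is a minimal illustration assuming the defaults above (text_tokens=256, audio_tokens=1024); fold_one() is a hypothetical name, batching and the sep/stop markers are omitted, and this is not the repo's actual API.

# Minimal sketch of the unified token-ID layout behind fold_inputs(), assuming
# text_tokens=256 and audio_tokens=1024 as in the defaults above. Batching and
# the sep/stop markers are omitted; fold_one() is a hypothetical helper.
import torch

def fold_one( text, prom, resp, quant_level=None, text_tokens=256, audio_tokens=1024, max_levels=8 ):
	# text: (t,) token IDs < text_tokens
	# prom/resp: (frames, max_levels) RVQ codes < audio_tokens
	prom_offset = text_tokens
	resp_offset = text_tokens + audio_tokens * max_levels

	if quant_level is not None:
		# non-interleaved (HF-compat) path: keep a single RVQ level and shift it
		# into that level's sub-vocabulary
		prom_ids = prom[:, quant_level] + prom_offset + audio_tokens * quant_level
		resp_ids = resp[:, quant_level] + resp_offset + audio_tokens * quant_level
	else:
		# interleaved path: flatten frame-major so position idx carries RVQ level
		# (idx % max_levels), then shift by that level's sub-vocabulary
		prom_levels = torch.arange( prom.numel() ) % max_levels
		resp_levels = torch.arange( resp.numel() ) % max_levels
		prom_ids = prom.flatten() + prom_offset + audio_tokens * prom_levels
		resp_ids = resp.flatten() + resp_offset + audio_tokens * resp_levels

	return torch.cat([ text, prom_ids, resp_ids ])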
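
The patched unfold_outputs() is the corresponding inverse: strip the offsets back off, and when quant_levels is given skip the interleaved re-binning and treat the recovered codes as a single RVQ level. Below is a hypothetical single-sequence inverse of the sketch above, under the same assumptions.

def unfold_one( ids, text_tokens=256, audio_tokens=1024, max_levels=8 ):
	# hypothetical inverse of fold_one() above, not the repo's unfold_outputs();
	# sep/stop markers are not treated specially here
	resp_offset = text_tokens + audio_tokens * max_levels
	text, prom, resp = [], [], []
	for id in ids.tolist():
		if id < text_tokens:
			text.append( id )
		elif id < resp_offset:
			prom.append( ( id - text_tokens ) % audio_tokens )
		else:
			resp.append( ( id - text_tokens ) % audio_tokens )
	return text, prom, resp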
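
On the training side, the change to example_usage() folds each batch item at a single randomly drawn RVQ level per step whenever cfg.model.interleave is off. Reusing the hypothetical fold_one() above with made-up shapes, that batching step looks roughly like this:

# Sketch of the per-step batching when interleaving is off: one random RVQ level
# per item. Shapes and data here are made up for illustration; fold_one() is the
# hypothetical helper defined above, standing in for fold_inputs(..., quant_levels=...).
import torch

batch_size, frames, max_levels = 2, 6, 8
text_list  = [ torch.randint( 0, 256,  (4,) )                 for _ in range( batch_size ) ]
proms_list = [ torch.randint( 0, 1024, (frames, max_levels) ) for _ in range( batch_size ) ]
resps_list = [ torch.randint( 0, 1024, (frames, max_levels) ) for _ in range( batch_size ) ]

quant_levels = torch.randint( 0, max_levels, (batch_size,) )  # one level per batch item
input_ids = [
	fold_one( text_list[i], proms_list[i], resps_list[i], quant_level=quant_levels[i].item() )
	for i in range( batch_size )
]

At inference time the same idea is applied level by level in sample(): level 0 is decoded autoregressively until the stop token, its length then fixes the generation length for every higher level, and the per-level outputs are stacked into a (frames, levels) matrix before being decoded to audio.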