diff --git a/vall_e/data.py b/vall_e/data.py
index f151bcb..4937c51 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -76,6 +76,14 @@ def fold_inputs(
 		input_ids[i].append( sep )
 
 	offset = text_tokens
+	# inject target quant_level
+	if quant_levels is not None:
+		for i, rvq in enumerate( quant_levels ):
+			seq = torch.Tensor([offset + rvq]).to("cpu", dtype=torch.int64)
+			input_ids[i].append( seq )
+			input_ids[i].append( sep )
+
+	offset = text_tokens + audio_rvq_levels
 	for i, prom in enumerate(prom_list):
 		# deinterleaved
 		if quant_levels is not None:
@@ -98,7 +106,7 @@ def fold_inputs(
 			input_ids[i].append( seq )
 			input_ids[i].append( sep )
 
-	offset = text_tokens + (audio_tokens * audio_rvq_levels)
+	offset = text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels)
 	for i, resp in enumerate(resp_list):
 		# deinterleaved
@@ -107,7 +115,10 @@ def fold_inputs(
 			quant_level = quant_levels[i] - 1
 			# way to signal we want to inference for rvq level 0
 			# without it, it's a random chance for any level to be selected again
+			if quant_level < 0:
+				continue
+
 			if quant_level < 0:
 				seq = sep
 			else:
 				# my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples
@@ -192,10 +203,10 @@ def unfold_outputs(
 			if 0 <= id and id < text_tokens:
 				text_list[i].append( id )
-			elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels):
-				prom_list[i].append( (id - text_tokens) % audio_tokens )
-			elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
-				resp_list[i].append( (id - text_tokens) % audio_tokens )
+			elif text_tokens + audio_rvq_levels <= id and id < text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels):
+				prom_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
+			elif text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels) <= id:
+				resp_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens )
 
 			if not flushed:
 				should_flush = True
diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
index ac499c0..1e167c6 100644
--- a/vall_e/models/experimental.py
+++ b/vall_e/models/experimental.py
@@ -71,6 +71,7 @@ try:
 	MambaMixelModel.forward = MambaMixelModel_forward
 
 	AVAILABLE_ARCHES.append("mamba")
+	AVAILABLE_ARCHES.append("mamba2")
 except Exception as e:
 	print("Error importing `mamba` arch:", e)
 	pass
@@ -80,7 +81,7 @@ SELECTED_ARCH = cfg.model.arch_type
 if SELECTED_ARCH not in AVAILABLE_ARCHES:
 	raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
 
-if SELECTED_ARCH == "mamba":
+if SELECTED_ARCH in ["mamba","mamba2"]:
 	LlmArchClass = MambaLMHeadModel
 elif SELECTED_ARCH == "llama":
 	LlmArchClass = LlamaForCausalLM
@@ -103,7 +104,8 @@ class Model(LlmArchClass):
 		hf_attention = config.attention if config is not None else None
 		gradient_checkpointing = config.gradient_checkpointing if config is not None else True
 
-		vocab_size = 256 + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
+		# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
+		vocab_size = 256 + cfg.model.max_levels + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
 
 		if SELECTED_ARCH == "llama":
 			super().__init__(config=LlamaConfig(
@@ -148,12 +150,12 @@ class Model(LlmArchClass):
 				decoder_normalize_before=True,
 			))
-		elif SELECTED_ARCH == "mamba":
+		elif SELECTED_ARCH in ["mamba","mamba2"]:
 			super().__init__(config=MambaConfig(
 				vocab_size=vocab_size,
 				d_model=d_model,
 				n_layer=n_layers*2,
-				#ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
+				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if SELECTED_ARCH == "mamba2" else {},
 			))
 
 		self.backbone.gradient_checkpointing = gradient_checkpointing
@@ -163,7 +165,7 @@ class Model(LlmArchClass):
 		*args,
 		**kwargs
 	):
-		if SELECTED_ARCH == "mamba":
+		if SELECTED_ARCH in ["mamba","mamba2"]:
 			kwargs["cg"] = True
 
 			if "attention_mask" in kwargs:
@@ -182,7 +184,7 @@ class Model(LlmArchClass):
 		*args,
 		**kwargs,
 	):
-		if SELECTED_ARCH == "mamba":
+		if SELECTED_ARCH in ["mamba","mamba2"]:
 			if "attention_mask" in kwargs:
 				kwargs.pop("attention_mask")
@@ -193,7 +195,7 @@ class Model(LlmArchClass):
 			self.loss = dict(
 				nll = output.loss,
 			)
-		elif SELECTED_ARCH == "mamba":
+		elif SELECTED_ARCH in ["mamba","mamba2"]:
 			if "labels" in kwargs:
 				labels = kwargs.pop("labels")
 				logits = output.logits
@@ -262,38 +264,15 @@ def example_usage():
 	prom_list = prom_list[:1]
 	resp_list = resp_list[:1]
 
-	if False:
-		output_list = [ [] ]
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[0])
-		unfolded = unfold_outputs( input_ids, quant_levels=[0])
-		print( 0, "inputs:", input_ids.shape, input_ids )
-		print( 0, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 0] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[1])
-		unfolded = unfold_outputs( input_ids, quant_levels=[1])
-		print( 1, "inputs:", input_ids.shape, input_ids )
-		print( 1, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 1] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[2])
-		unfolded = unfold_outputs( input_ids, quant_levels=[2])
-		print( 2, "inputs:", input_ids.shape, input_ids )
-		print( 2, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 2] )
-
-		input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[3])
-		unfolded = unfold_outputs( input_ids, quant_levels=[3])
-		print( 3, "inputs:", input_ids.shape, input_ids )
-		print( 3, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
-		output_list[0].append( resp_list[0][:, 3] )
-
-		return
-
 	kwargs = {}
 	model = Model(**kwargs).to(device)
-	steps = 50 if cfg.model.interleave else 250
+	steps = 100
+	if cfg.model.arch_type == "mamba2":
+		steps = 100
+	elif cfg.model.arch_type == "llama":
+		steps = 500
+	elif cfg.model.interleave:
+		steps = 250
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
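
Not part of the patch itself, but for reference: a minimal sketch of the token-id layout the patch moves to. The constants mirror the ones hard-coded in experimental.py (256 text tokens, 1024 codes per codebook), audio_rvq_levels stands in for cfg.model.max_levels (assumed to be 8 here for illustration), and classify() is a hypothetical helper that only mirrors the range checks in the patched unfold_outputs().

# Illustrative only, not part of the diff.
text_tokens      = 256   # text/phoneme vocabulary
audio_rvq_levels = 8     # stand-in for cfg.model.max_levels
audio_tokens     = 1024  # codes per codebook

# Offsets into the shared vocabulary, in the order fold_inputs() emits spans:
# [text] [target quant level] [prom codes] [resp codes] [stop]
level_offset = text_tokens
prom_offset  = text_tokens + audio_rvq_levels
resp_offset  = prom_offset + (audio_tokens * audio_rvq_levels)
stop_token   = resp_offset + (audio_tokens * audio_rvq_levels)
vocab_size   = stop_token + 1  # matches the new vocab_size expression in experimental.py

def classify( id ):
	# Hypothetical helper mirroring the range checks in the patched unfold_outputs().
	if 0 <= id < text_tokens:
		return "text", id
	if level_offset <= id < prom_offset:
		return "quant_level", id - level_offset
	if prom_offset <= id < resp_offset:
		# equivalent to (id - text_tokens - audio_rvq_levels) % audio_tokens in the patch
		return "prom", (id - prom_offset) % audio_tokens
	if resp_offset <= id < stop_token:
		return "resp", (id - resp_offset) % audio_tokens
	return "stop", None

if __name__ == "__main__":
	print( vocab_size )                                       # 256 + 8 + 8192 + 8192 + 1 = 16649
	print( classify( level_offset + 3 ) )                     # ('quant_level', 3): target RVQ level 3
	print( classify( prom_offset + 2 * audio_tokens + 5 ) )   # ('prom', 5): prompt level 2, code 5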