From 2a1794c0840215a562b7454612f1cabb31f579d3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 9 Aug 2024 21:15:01 -0500
Subject: [PATCH] ughghghhhh

---
 vall_e/data.py                |  46 +++++++++++--
 vall_e/models/ar_nar.py       |   4 +-
 vall_e/models/arch/mamba.py   |   4 +-
 vall_e/models/base.py         |  10 +--
 vall_e/models/experimental.py | 126 +++++++++++++++++++++++++++++-----
 5 files changed, 159 insertions(+), 31 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index edcf618..c01d01a 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -58,9 +58,11 @@ def fold_inputs(
 		stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
 		return (seq < stop).float() # (b t)
 
-	def list_to_tensor(x_list: list[Tensor]):
+	def list_to_tensor(x_list: list[Tensor], mask=True):
 		l = list(map(len, x_list))
 		x = pad_sequence(x_list).t()
+		if not mask:
+			return x
 
 		m = _create_mask(l, x_list[0].device)
 		m = m.to(x)
@@ -68,7 +70,7 @@ def fold_inputs(
 
 	def process_prom_or_task(i, prom):
 		if prom is None:
-			return
+			return 0
 
 		if isinstance(prom, str):
 			task = get_task_symmap()[f'<{input}>']
@@ -76,7 +78,8 @@
 
 			input_ids[i].append( seq )
 			input_ids[i].append( sep )
-			return
+
+			return seq.shape[0] + 1
 
 		# deinterleaved
 		if quant_levels is not None:
@@ -99,6 +102,11 @@ def fold_inputs(
 			input_ids[i].append( seq )
 			input_ids[i].append( sep )
 
+		return seq.shape[0] + 1
+
+	def generate_position_ids( length, sep=True ):
+		return [ i for i in range( length + (1 if sep else 0) ) ]
+
 	"""
 	if quant_levels is not None:
 		resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
@@ -109,6 +117,7 @@
 	batch_size = len(text_list)
 
 	input_ids = [ [] for _ in range(batch_size) ]
+	position_ids = [ [] for _ in range(batch_size) ]
 
 	offset = 0
 
@@ -142,17 +151,23 @@
 			seq = text + text_start
 		else:
 			seq = torch.tensor([text_start + text], device=device, dtype=dtype)
+
 		input_ids[i].append( seq )
 		input_ids[i].append( sep )
 
+		position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
 	# lang tokens
 	for i, lang in enumerate(lang_list):
 		if isinstance(lang, torch.Tensor):
 			seq = lang + lang_start
 		else:
 			seq = torch.tensor([lang_start + lang], device=device, dtype=dtype)
+
 		input_ids[i].append( seq )
 		input_ids[i].append( sep )
+
+		position_ids[i].append( generate_position_ids( seq.shape[0] ) )
 
 	# inject target quant_level
 	if quant_levels is not None:
@@ -164,15 +179,20 @@
 		input_ids[i].append( seq )
 		input_ids[i].append( sep )
 
+		position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
 	# prom / task tokens
 	for i, prom in enumerate(prom_list):
 		# list of proms with a possible task token
+		length = 0
 		if isinstance(prom, list):
 			for p in prom:
-				process_prom_or_task(i, p)
+				length += process_prom_or_task(i, p)
 		# raw tensor
 		else:
-			process_prom_or_task(i, prom)
+			length += process_prom_or_task(i, prom)
+
+		position_ids[i].append( generate_position_ids( length, sep=False ) )
 
 	# tone tokens
 	for i, tone in enumerate(tone_list):
@@ -183,6 +203,8 @@
 		input_ids[i].append( seq )
 		input_ids[i].append( sep )
 
+		position_ids[i].append( generate_position_ids( seq.shape[0] ) )
+
 	# resp tokens
 	for i, resp in enumerate(resp_list):
 		# deinterleaved
@@ -205,6 +227,8 @@
 
 				input_ids[i].append( seq )
 				input_ids[i].append( stop )
+
+				position_ids[i].append( generate_position_ids( seq.shape[0] ) )
 		# interleaved
 		else:
 			seq = resp.flatten().to(device=device, dtype=dtype)
@@ -213,6 +237,8 @@
 
 			input_ids[i].append( seq )
 			input_ids[i].append( stop )
+
+			position_ids[i].append( generate_position_ids( seq.shape[0] ) )
 
 	# targ list
 	for i, resp in enumerate(targ_list):
@@ -225,6 +251,8 @@
 
 				input_ids[i].append( seq )
 				input_ids[i].append( stop )
+
+				position_ids[i].append( generate_position_ids( seq.shape[0] ) )
 		# interleaved
 		else:
 			seq = resp.flatten().to(device=device, dtype=dtype)
@@ -233,11 +261,17 @@
 
 			input_ids[i].append( seq )
 			input_ids[i].append( stop )
+
+			position_ids[i].append( generate_position_ids( seq.shape[0] ) )
 
 	for i, batch in enumerate(input_ids):
 		input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=dtype)
+		position_ids[i] = torch.concat([ torch.tensor(ids, device=device, dtype=dtype) for ids in position_ids[i] ], dim=-1)
 
-	return list_to_tensor(input_ids)
+	input_ids, attention_mask = list_to_tensor(input_ids)
+	position_ids = list_to_tensor(position_ids, mask=False)
+
+	return input_ids, attention_mask, position_ids
 
 # unfold from one unified token ID space to separate token spaces
 # to-do: unfold at a specific RVQ level instead if requested
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 18bea78..da90e17 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -316,7 +316,7 @@ class AR_NAR(Base):
 
 
 def example_usage():
-	# cfg.trainer.backend = "local"
+	cfg.trainer.backend = "local"
 	cfg.hyperparameters.gradient_accumulation_steps = 1
 	if cfg.audio_backend == "dac":
 		cfg.sample_rate = 44_100
@@ -398,7 +398,7 @@ def example_usage():
 	tasks = cfg.dataset.tasks_list
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 150 * len(tasks) * cfg.model.experimental.causal_size
+	steps = 150 * len(tasks) # * cfg.model.experimental.causal_size
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
diff --git a/vall_e/models/arch/mamba.py b/vall_e/models/arch/mamba.py
index 4389b13..e9ae498 100644
--- a/vall_e/models/arch/mamba.py
+++ b/vall_e/models/arch/mamba.py
@@ -9,9 +9,9 @@ def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_
 	residual = None
 	for layer in self.layers:
 		if self.gradient_checkpointing and hidden_states.requires_grad:
-			hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
+			hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, **mixer_kwargs, use_reentrant=False )
 		else:
-			hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
+			hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params, **mixer_kwargs )
 	if not self.fused_add_norm:
 		residual = (hidden_states + residual) if residual is not None else hidden_states
 		hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 0e02116..8987e58 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -701,12 +701,12 @@ class Base(nn.Module):
 			self.model = MambaMixelModel(
 				vocab_size=n_resp_tokens,
 				d_model=d_model,
-				n_layer=n_layers,
-				d_intermediate=d_model*4,
-				ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": False} if self.arch_type == "mamba2" else {},
+				n_layer=n_layers*2,
+				d_intermediate=0, #d_model*2,
+				ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if self.arch_type == "mamba2" else {},
 				rms_norm=True,
 				fused_add_norm=True,
-				residual_in_fp32=False,
+				residual_in_fp32=True,
 				#attn_layer_idx=attn_layer_idx,
 				#attn_cfg=attn_cfg,
 				#initializer_cfg=initializer_cfg,
@@ -722,7 +722,7 @@ class Base(nn.Module):
 				is_encoder_decoder=False,
 				is_decoder=True,
 				use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
-				residual_in_fp32=False, # breaks for AMP inference
+				residual_in_fp32=True, # breaks for AMP inference
 			))
 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
index d7266a2..b652b26 100644
--- a/vall_e/models/experimental.py
+++ b/vall_e/models/experimental.py
@@ -59,6 +59,15 @@ class Model(LlmArchClass):
 		# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
 		# vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1
 
+		if hf_attention == "auto":
+			if AVAILABLE_ATTENTIONS:
+				hf_attention = AVAILABLE_ATTENTIONS[0]
+			else:
+				hf_attention = "eager"
+
+		if hf_attention == "xformers":
+			hf_attention = "mem_efficient"
+
 		text_start = 0
 		text_end = text_start + config.text_tokens
 
@@ -82,17 +91,17 @@
 
 		vocab_size = resp_end
 
-		if cfg.model.arch_type == "llama":
+		if config.arch_type == "llama":
 			super().__init__(config=LlamaConfig(
 				vocab_size=vocab_size,
 				hidden_size=d_model,
-				max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.max_levels * 60, # max-length of 60 seconds
+				max_position_embeddings=cfg.dataset.frames_per_second * config.max_levels * 60, # max-length of 60 seconds
 				intermediate_size=d_model*4,
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
 				attention_dropout=p_dropout,
 				num_key_value_heads=n_heads,
-				sliding_window=cfg.dataset.frames_per_second * cfg.model.max_levels * 12,
+				sliding_window=cfg.dataset.frames_per_second * config.max_levels * 12,
 				hidden_act="gelu",
 				is_encoder_decoder=False,
 				is_decoder=True,
@@ -103,7 +112,7 @@ class Model(LlmArchClass):
 				self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
 					use_reentrant=False
 				))
-		elif cfg.model.arch_type == "retnet":
+		elif config.arch_type == "retnet":
 			super().__init__(config=RetNetConfig(
 				vocab_size=vocab_size,
 				decoder_embed_dim=d_model,
@@ -125,16 +134,16 @@
 				decoder_normalize_before=True,
 			))
 
-		elif cfg.model.arch_type in ["mamba","mamba2"]:
+		elif config.arch_type in ["mamba","mamba2"]:
 			super().__init__(config=MambaConfig(
 				vocab_size=vocab_size,
 				d_model=d_model,
-				n_layer=n_layers,
-				d_intermediate=d_model*4,
-				ssm_cfg={"layer": "Mamba2"} if cfg.model.arch_type == "mamba2" else {},
+				n_layer=n_layers*2,
+				d_intermediate=0, # d_model*4,
+				ssm_cfg={"layer": "Mamba2", "use_mem_eff_path": True} if config.arch_type == "mamba2" else {},
 				rms_norm=True,
 				fused_add_norm=True,
-				residual_in_fp32=True,
+				residual_in_fp32=False,
 			))
 
 			self.backbone.gradient_checkpointing = gradient_checkpointing
@@ -163,8 +172,87 @@
 
 			if "min_length" in kwargs:
 				kwargs.pop("min_length")
+
+			"""
+			if "position_ids" in kwargs:
+				kwargs.pop("position_ids")
+
+			if "max_new_tokens" in kwargs:
+				kwargs.pop("max_new_tokens")
+			if "max_length" not in kwargs:
+				kwargs["max_length"] = 500 * (self.hyper_config.resp_levels if self.hyper_config.experimental.interleave else 1)
+
+			if "num_last_tokens" not in kwargs:
+				kwargs["num_last_tokens"] = self.hyper_config.experimental.causal_size
+			"""
+
+		input_ids = kwargs.pop("input_ids")
+		attention_mask = kwargs.pop("attention_mask", None)
+		position_ids = kwargs.pop("position_ids", None)
+
+		stop_token = kwargs.pop("eos_token_id", 3)
+		max_steps = kwargs.pop("max_new_tokens", 500)
+
+		device = input_ids.device
+		batch_size = input_ids.shape[0]
+
+		sequence_list = [ inputs for inputs in input_ids ]
+		position_list = [ positions for positions in position_ids ]
+
+		start_positions = [ inputs.shape[0] for inputs in input_ids ]
+
+		stopped = torch.zeros(batch_size, device=device).bool()
+
+		config = self.hyper_config
+		state = None
+		disable_tqdm = False
+		causal_size = config.experimental.causal_size
+
+		# get next in sequence
+		for n in trange(max_steps // max(1, causal_size), desc="AR", disable=disable_tqdm):
+			output = super().forward(
+				input_ids=torch.stack(sequence_list),
+				#attention_mask=attention_mask,
+				#past_key_values=state,
+				#position_ids=torch.stack(position_list),
+				#use_cache=False,
+				#return_dict=False
+			)
+
+			logits = output[0]
+			# state = output[1]
+
+			r = [ logit[-causal_size:].argmax(dim=1) for logit in logits ]
+
+			# append tokens
+			for i, ri in enumerate(r):
+				if stop_token in ri:
+					stopped[i] = True
+
+				last_position_id = position_list[i][-1].item() + 1
+				sequence_list[i] = torch.cat([ sequence_list[i], ri.to(device) ], dim=0)
+				#position_list[i] = torch.cat([ position_list[i], torch.tensor([ last_position_id + _ for _ in range( ri.shape[0] ) ], device=device, dtype=torch.int32) ])
+
+			# stop token found
+			stopped |= r == stop_token
+			if stopped.all().item():
+				break
+
+		def _prune(l: Tensor, stop = stop_token):
+			indices = (l == stop).nonzero()
+
+			if len(indices) == 0:
+				return l
+
+			return l[: indices.min().item()]
+
+		sequence_list = [ _prune(seq[start_positions[i]:], stop_token) for i, seq in enumerate(sequence_list) ]
+		return torch.stack(sequence_list)
+
+		"""
 
 		return super().generate(*args, **kwargs)
+		"""
 
 	def forward(
 		self,
@@ -188,14 +276,14 @@
 		if training:
 			quant_levels = None if config.experimental.interleave else [ random.randint( 0 if "ar" in config.capabilities else 1, config.max_levels - 1) for _ in range(batch_size) ]
 
-			input_ids, attention_mask = fold_inputs(
+			input_ids, attention_mask, position_ids = fold_inputs(
 				text_list=text_list,
 				prom_list=proms_list,
 				resp_list=resps_list,
 				targ_list=resps_list,
 				quant_levels=quant_levels,
 			)
-			target_ids, target_attention_mask = fold_inputs(
+			target_ids, target_attention_mask, target_position_ids = fold_inputs(
 				text_list=text_list,
 				prom_list=proms_list,
 				resp_list=resps_list,
@@ -206,14 +294,16 @@
 			return self.forward(
 				input_ids=input_ids,
 				labels=target_ids,
+				position_ids=position_ids,
 				quant_levels=quant_levels,
 			)
 
 		if config.experimental.interleave:
-			input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list )
+			input_ids, attention_mask, position_ids = fold_inputs( text_list=text_list, prom_list=proms_list )
 
 			output = self.generate(
 				input_ids=input_ids,
+				position_ids=position_ids,
 				attention_mask=attention_mask,
 				eos_token_id=3,
 				do_sample=True,
@@ -225,7 +315,7 @@
 		for l in range(config.max_levels):
 			quant_levels = [ l for _ in range(batch_size) ]
-			input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
+			input_ids, attention_mask, position_ids = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels)
 
 			min_length = 1
 			for batch in input_ids:
 				min_length = max( min_length, batch.shape[0] + 1 )
@@ -234,6 +324,7 @@
 			output = self.generate(
 				input_ids=input_ids,
 				attention_mask=attention_mask,
+				position_ids=position_ids,
 				eos_token_id=3,
 				do_sample=True,
 				max_new_tokens=steps,
@@ -273,10 +364,13 @@
 
 		# i HATE the correct way
 		if labels is not None:
+			if quant_levels is None:
+				quant_levels = [0 for _ in range(labels.shape[0])]
+
 			# predict the next token for AR, else predict in place
 			loss = sum([ F.cross_entropy(
-				logit[:-1, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
-				label[1:] if quant_level == 0 or "nar" not in config.capabilities else label,
+				logit[:-config.experimental.causal_size, :] if quant_level == 0 or "nar" not in config.capabilities else logit,
+				label[config.experimental.causal_size:] if quant_level == 0 or "nar" not in config.capabilities else label,
 				ignore_index=-100
 			) for logit, label, quant_level in zip( logits, labels, quant_levels ) ])
 
@@ -372,7 +466,7 @@ def example_usage():
 
 	kwargs = {}
 	model = Model(**kwargs).to(device)
-	steps = 100 if cfg.model.experimental.interleave else 300
+	steps = 50 # 100 if cfg.model.experimental.interleave else 300
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
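
As a reading aid for the vall_e/data.py changes above: fold_inputs now builds per-segment position IDs, where every folded segment (text, lang, prom/task, tone, resp) restarts its zero-based position range and optionally covers the trailing separator token, and the concatenated ranges are returned alongside input_ids and the attention mask. The following is a minimal, self-contained sketch of that scheme; the helper mirrors the generate_position_ids added in the patch, while the segment lengths are made-up values for illustration only, not taken from the repository.

import torch

def generate_position_ids(length: int, sep: bool = True) -> list[int]:
	# zero-based positions for one folded segment, plus one extra slot for the
	# trailing separator token when sep=True (mirrors the helper in vall_e/data.py)
	return list(range(length + (1 if sep else 0)))

# made-up segment lengths for a single folded sample (e.g. text, prom, resp)
segment_lengths = [3, 5, 4]

# positions restart at 0 for every segment rather than running over the whole
# concatenated sequence, which is why position_ids must be returned explicitly
position_ids = torch.cat([
	torch.tensor(generate_position_ids(l), dtype=torch.int64)
	for l in segment_lengths
])

print(position_ids.tolist())
# [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4]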