diff --git a/docs/models.md b/docs/models.md
index a1cd6ea..4869b74 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -56,7 +56,13 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
 The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
 * incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation

-To-do: fill out this more when it works.
+To-do: fill out this more when it works. Getting this to work is a huge pain.
+* Some masked transformers do not "inject" any timestep information (Text-To-Image Muse, as far as I can tell)
+* Others "expose" it by applying a timestep embedding after pre/post attention normalization
+  * Except F5-TTS only does this pre-attention for its DiT, but not its UNetT
+  * MaskGCT does it both pre and post
+  * the test trainer actually degrades the output immensely when doing this
+* I'm sure I've seen a masked transformer without CFG, but most of them seem to use it (and all seem to be poorly documented on specifically how it's done, at least for my dumb brain)

 ## Embeddings
diff --git a/vall_e/config.py b/vall_e/config.py
index a7a750a..f99f093 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -258,6 +258,11 @@ class ModelExperimentalSettings:
     len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
     # to-to: just incorporate this as a task instead

+    # classifier-free guidance shit
+    cfg_cond_dropout_p: float = 0.2 # probability to drop out both text and audio prompt during training
+    cfg_text_dropout_p: float = 0.0 # probability to drop out input text during training
+    cfg_prom_dropout_p: float = 0.3 # probability to drop out input audio prompt during training
+
     layerskip: bool = False # layerskip compatible model (or training for)
     #layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
     layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
diff --git a/vall_e/models/arch/__init__.py b/vall_e/models/arch/__init__.py
index 4eb89d7..65a2a94 100755
--- a/vall_e/models/arch/__init__.py
+++ b/vall_e/models/arch/__init__.py
@@ -30,7 +30,7 @@ except Exception as e:
     pass

 try:
-    from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM
+    from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaDecoderLayer, LlamaDecoderLayer_Adapted, LlamaForCausalLM
     AVAILABLE_ARCHES.append("llama")
 except Exception as e:
     ERROR_ARCHES["llama"] = e
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 25755a1..e8b28b8 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -13,7 +13,7 @@ from transformers.cache_utils import Cache
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv

 _logger = logging.getLogger(__name__)

@@ -304,15 +304,125 @@ class LlamaAttention_Adapted(LlamaAttention):
         return attn_output, attn_scores, past_key_value

+class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
+    # apply timestep embedding with attention norm
+    # I don't have a concrete idea on how helpful this is, as:
+    # * F5-TTS's UNetT implementation doesn't do this
+    # * F5-TTS's DiT does this, but only for pre-attention normalization
+    # * MaskGCT does this for both
+    # * Muse doesn't do this, but instead appends the timestep embedding
+    def weigh_by_timestep(
+        self,
+        hidden_states,
+        timesteps,
+    ):
+        if timesteps is None:
+            return hidden_states
+
+        for i, timestep in enumerate( timesteps ):
+            # invalid
+            if not isinstance( timestep, torch.Tensor ):
+                continue
+            hidden_states[i] *= timestep
+
+        return hidden_states
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        timesteps: Optional[list] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
+
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
 class LlamaModel_Adapted(LlamaModel):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, config, *args, **kwargs):
         self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
         self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
         self.early_exit_r = kwargs.pop("early_exit_r", 2)

-        super().__init__(*args, **kwargs)
+        #super().__init__(*args, **kwargs)
+        super(LlamaModel, self).__init__(config)
+
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.layers_n = config.num_hidden_layers
+
+        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+
+        self.layers = nn.ModuleList(
+            [LlamaDecoderLayer_Adapted(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()

-        self.layers_n = len(self.layers)

     def dropoff_layer( self, l ):
         if not self.training:
             return False
@@ -360,6 +470,7 @@ class LlamaModel_Adapted(LlamaModel):
         cache_position: Optional[torch.LongTensor] = None,
         layer_skip_lambda = None,
+        timesteps: Optional[list] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -377,8 +488,10 @@
             )
             use_cache = False

+        """
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
+        """

         # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
@@ -430,6 +543,7 @@
                     use_cache,
                     cache_position,
                     position_embeddings,
+                    timesteps,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -441,6 +555,7 @@
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    timesteps=timesteps,
                 )

             if not self.dropoff_layer( l ):
@@ -469,6 +584,7 @@
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index aebd075..21a7e81 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -53,7 +53,9 @@ def _dropout_mask( input, p=None ):
         t = random.random()
         p = math.cos(t * math.pi * 0.5)

-    return torch.tensor( [ random.random() < p for _ in range( input.shape[0] ) ], dtype=torch.bool, device=input.device )
+    seq = [ random.random() < p for _ in range( input.shape[0] ) ]
+    mask = torch.tensor( seq, dtype=torch.bool, device=input.device )
+    return mask

 def clamp(n, lo, hi):
     return max(lo, min(n, hi))
@@ -646,7 +648,8 @@ class Base(nn.Module):
                 use_reentrant=False
             ))
         elif self.arch_type == "llama":
-            LlamaClass = LlamaModel_Adapted if self.layerskip else LlamaModel
+            LlamaClass = LlamaModel_Adapted if (self.layerskip or "len" in self.capabilities) else LlamaModel
+
             if n_experts <= 1:
                 self.model = LlamaClass(LlamaConfig(
                     vocab_size=n_resp_tokens,
@@ -664,8 +667,16 @@
                     attn_implementation=hf_attention,
                     #gradient_checkpointing=self.gradient_checkpointing,
                 ))
+
+                # replace with desired attention
                 if attention_backend not in HF_ATTENTIONS:
                     self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
+
+                # replace with modified Llama
+                """
+                if "len" in self.capabilities:
+                    self.model = ml.replace_attention( self.model, klass=LlamaDecoderLayer_Adapted, target=LlamaDecoderLayer, mode=attention_backend )
+                """
             else:
                 self.model = MixtralModel(MixtralConfig(
                     vocab_size =n_resp_tokens,
@@ -866,6 +877,7 @@
         state = None,
         layer_skip_lambda = None,
+        timesteps = None,
         output_attentions = False,
         output_hidden_states = False,
@@ -896,6 +908,9 @@
         if self.layerskip and layer_skip_lambda is not None:
             kwargs["layer_skip_lambda"] = layer_skip_lambda

+        if "len" in self.capabilities and timesteps is not None:
+            kwargs["timesteps"] = timesteps
+
         output = self.model(**kwargs)
         x = output["last_hidden_state"]
@@ -1036,8 +1051,12 @@
                 p = math.cos(t * math.pi * 0.5)
                 dropout_mask = _dropout_mask( resps_list[i], p=p )

-                inputs[i].append( ("timestep", torch.tensor(t, device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
+                inputs[i].append( ("timestep", torch.tensor([t], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
                 inputs[i].append( ("dropout_mask", dropout_mask ) )
+            else:
+                # in the event it's needed (it makes shit sound worse)
+                #inputs[i].append( ("timestep", torch.tensor([1.0], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
+                ...

         # Audio length prediction task
         # Sequence:
@@ -1088,7 +1107,6 @@
                 inputs[i].append( ( "text", text_list[i] ) )
             else:
                 raise Exception(f'Unrecognized task: {task_type}')
-
         return inputs

     def inputs_to_embeddings(
@@ -1139,13 +1157,10 @@
             for name, input in batch_input:
                 if name == "dropout_mask":
                     dropout_mask = input
-                elif name == "timestep":
-                    timestep = input

             for name, input in batch_input:
                 # technically can provide a map for input_name => embedding, but some embedding requires additional processing
                 embedding = None
-
                 # is already an embedding
                 if name == "task":
                     # noop
@@ -1162,6 +1177,10 @@
                     embedding = self.langs_emb( input )
                 elif name == "prom":
                     proms = [ input ] if isinstance(input, torch.Tensor) else input
+                    """
+                    if proms is None:
+                        continue
+                    """
                     # to-do: probably insert separators if task requires it?
                    embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
                 elif name == "tone" and self.tones_emb is not None:
@@ -1179,13 +1198,11 @@
                 # if training NAR-len RVQ level 0
                 elif "len" in self.capabilities and quant_level == 0 and dropout_mask is not None:
                     embedding = self.resps_emb(
+                        # if masked use masked token, else original token
                         torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
                         offset = 0,
                         quant_level = 0,
                     )
-
-                    t_emb = self.time_emb( timestep )
-                    embedding += t_emb
                 # cheat-y way to handle performing STT across all levels
                 elif task_type in summed_embeddings_task:
                     # we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
@@ -1227,18 +1244,36 @@
                         continue
                         embedding[i] = self.dropout_token
-
+                elif name == "timestep" and self.time_emb is not None:
+                    embedding = self.time_emb( input )
                 elif name == "len" and self.len_emb is not None:
                     embedding = self.len_emb( input )
                 else:
                     # should probably raise an exception so things aren't processed silently
                     continue
+
                 batch.append(embedding)

             x_list.append( _join( batch, self.sep ) )

         return x_list

+    # get an attribute from a given input list
+    def get_input(
+        self,
+        inputs,
+        name,
+        at=None,
+    ):
+        for batch_index, batch_input in enumerate(inputs):
+            if at is not None and batch_index != at:
+                continue
+
+            for n, input in batch_input:
+                if n == name:
+                    return input
+        return None
+
     # creates position ids from a given input list
     # if not unified_position_ids, then each input segment will have its own sequence
     def inputs_to_position_ids(
@@ -1262,7 +1297,7 @@
                 return 1
             # a mask
-            if name in ["dropout_mask", "timestep"]:
+            if name in ["dropout_mask"]:
                 return 0
             # list of tokens
@@ -1342,6 +1377,7 @@
             elif name == "resp":
                 # mask found, apply it
                 if dropout_mask is not None:
+                    # if mask use original token, else ignore
                     target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
                 elif self.interleave:
                     target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
@@ -1350,6 +1386,8 @@
                     target.append( torch.full_like(input[..., 0], self.ignore_index) )
                 else:
                     target.append( input if input.dim() == 1 else input[:, quant_level] )
+            elif name == "timestep":
+                target.append( torch.tensor([self.ignore_index], device=input.device) )
             elif name in ["text", "quant_level", "lang", "tone", "len"]:
                 target.append( input )
@@ -1579,7 +1617,15 @@
         #position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
         position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None

-        classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
+        tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
+        """
+        timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
+        #timesteps = [ inputs[i][-1] if timestep is not None else None for i, timestep in enumerate(timesteps) ]
+        timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
+        """
+        timesteps = []
+
+        classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
        output = self._forward(
             inputs=x,
@@ -1589,6 +1635,7 @@
             output_attentions = output_attentions,
             output_hidden_states = output_hidden_states,
             layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
+            timesteps=timesteps,
         )

         logits = output.logits
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 5855580..c3c5690 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -128,6 +128,10 @@
         token_dropout_error = self.config.experimental.token_dropout_error
         # RVQ levels to apply token dropout on
         token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
+        # CFG
+        cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
+        cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
+        cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
         # implicitly set it to all levels
         if not token_dropout_rvq_levels:
             token_dropout_rvq_levels = [0, self.resp_levels - 1]
@@ -150,12 +154,10 @@
         # trim resps to only contain all levels below the target level
         resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
-
-        # tensor to cat for RVQ level 0
-        text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
-        audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
+        # empty string for CFG
+        text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)

         # I hate python's value/reference semantics so much
-        for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
+        for i, quant_level, text, resps, proms, task in zip(range(batch_size), quant_levels, text_list, resps_list, proms_list, task_list):
             # cap quant_level if it exceeds its corresponding resp/prom
             if quant_level >= resps.shape[-1]:
                 quant_levels[i] = resps.shape[-1] - 1
@@ -193,6 +195,24 @@
                 #resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
                 ...
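
The hunk that follows implements the training-time condition dropout for classifier-free guidance. As a standalone sketch of the intended behaviour (the helper name here is illustrative and not part of the patch; probabilities use the `cfg_*_dropout_p` defaults from `config.py` above, and note that `cfg_text_dropout_p` is read but not yet used in this hunk):

```python
import random

# mirrors the condition-dropout logic added to NAR.forward() below
def drop_conditions( cfg_prom_dropout_p=0.3, cfg_cond_dropout_p=0.2 ):
    drop_text = False
    drop_audio = False

    # drop only the input audio prompt
    if random.random() < cfg_prom_dropout_p:
        drop_audio = True

    # drop both text and audio prompt (fully unconditional sample)
    if random.random() < cfg_cond_dropout_p:
        drop_audio = True
        drop_text = True

    return drop_text, drop_audio

# with the defaults: P(drop_text) = 0.2, and since the two draws are independent,
# P(drop_audio) = 1 - (1 - 0.3) * (1 - 0.2) = 0.44
```
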
+            # apply CFG (should probably only apply to NAR quant level 0)
+            if task not in text_task:
+                drop_text = False
+                drop_audio = False
+
+                if random.random() < cfg_prom_dropout_p:
+                    drop_audio = True
+
+                if random.random() < cfg_cond_dropout_p:
+                    drop_audio = True
+                    drop_text = True
+
+                if drop_text:
+                    text_list[i] = text_start_stop_sequence
+
+                if drop_audio:
+                    proms_list[i] = None
+
         inputs = self.inputs(
             text_list=text_list,
             proms_list=proms_list,
@@ -225,7 +245,7 @@
             sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer

         # initial condition
-        len_list = [ min(l, 75*3) for l in len_list ]
+        len_list = [ clamp(l, 1, 75*3) for l in len_list ]
         metrics = []

         mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
@@ -279,6 +299,10 @@
             noise_p = math.cos( start_noise * math.pi * 0.5 )
             input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )

+        null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
+        null_prom = None
+        cfg_strength = 1.0
+
         for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
             # anneal temperature
             temperature = starting_temperature * (steps_until_x0 / max_steps)
@@ -294,6 +318,7 @@
             is_masked = input_ids == self.stop_token
             # setup inputs
             resps_list = [ input_ids ]
+
             inputs = _super.inputs(
                 text_list=text_list,
                 proms_list=proms_list,
@@ -308,11 +333,29 @@
                 quant_levels=quant_levels,
                 layer_skip_variables=sampling_layer_skip_variables,
             )
+            if cfg_strength > 0:
+                null_inputs = _super.inputs(
+                    text_list=[ null_text ],
+                    proms_list=[ null_prom ],
+                    resps_list=resps_list,
+                    lang_list=lang_list,
+                    tone_list=tone_list,
+                    time_list=[ timestep ],
+                    quant_levels=quant_levels,
+                )
+                null_output = _super.forward(
+                    inputs=null_inputs,
+                    quant_levels=quant_levels,
+                    layer_skip_variables=sampling_layer_skip_variables,
+                )
+                logits = [ logits + ( logits - null_logits ) * cfg_strength for logits, null_logits in zip(output.logits, null_output.logits) ]
+            else:
+                logits = output.logits

             # sample with sampler settings
             sampling_top_p = 0.9
             filtered_sampled = _super.sample(
-                logits=output.logits,
+                logits=logits,
                 prev_list=prev_list,
                 quant_levels=quant_levels,
@@ -328,7 +371,7 @@
             # retrieves unfiltered logits
             unfiltered_sampled = _super.sample(
-                logits=output.logits,
+                logits=logits,
                 prev_list=prev_list,
                 quant_levels=quant_levels,
                 temperature=0.0,
@@ -504,7 +547,6 @@ def example_usage():
     qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
-
     text_list = [
         tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
         #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
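
A note on the guidance combination used in the sampling loop above: `logits + (logits - null_logits) * cfg_strength` anchors the update on the conditional logits, which works out to an effective guidance scale of `1 + cfg_strength` over the usual `uncond + scale * (cond - uncond)` form. A minimal sketch (function name and tensor shapes are illustrative only):

```python
import torch

def apply_cfg( cond_logits: torch.Tensor, null_logits: torch.Tensor, cfg_strength: float = 1.0 ) -> torch.Tensor:
    # same combination as the sampling loop above:
    #   cond + cfg_strength * (cond - null)
    # algebraically equal to:
    #   null + (1 + cfg_strength) * (cond - null)
    # i.e. an effective guidance scale of 1 + cfg_strength
    return cond_logits + ( cond_logits - null_logits ) * cfg_strength

# example usage with made-up shapes (batch, seq_len, vocab)
cond = torch.randn(1, 10, 1024)
null = torch.randn(1, 10, 1024)
guided = apply_cfg( cond, null, cfg_strength=1.0 )
```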