This better work

This commit is contained in:
mrq 2024-11-09 18:04:59 -06:00
parent 8b3d1cf70a
commit c6a38693a2
6 changed files with 244 additions and 28 deletions

View File

@@ -56,7 +56,13 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
* incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation
To-do: fill out this more when it works.
To-do: fill out this more when it works. Getting this to work is a huge pain.
* Some masked transformers do not "inject" any timestep information at all (Text-To-Image Muse, as far as I can tell)
* Others "expose" it by applying a timestep embedding after pre/post attention normalization (see the sketch after this list)
* Except F5-TTS only does this pre-attention for its DiT, but not for its UNetT
* MaskGCT does it both pre and post
* the test trainer actually degrades the output immensely when doing this
* I'm sure I've seen a masked transformer without CFG, but most of them seem to use it (and all seem to be poorly documented on specifically how they apply it)
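Below is a minimal sketch of the pre/post-norm timestep weighting described above. It is illustrative only: the class and argument names are made up for this example, and it assumes a per-item timestep embedding `t_emb` has already been computed elsewhere (e.g. by a sinusoidal embedding plus MLP).

```python
import torch
import torch.nn as nn

class TimestepConditionedBlock(nn.Module):
    """Toy transformer block: scale the normalized hidden states by a timestep
    embedding before attention and before the MLP (the MaskGCT-style "both" variant);
    dropping the second scaling gives the F5-TTS DiT "pre only" variant."""
    def __init__(self, dim: int, heads: int = 8):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.mlp_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        # t_emb: (batch, 1, dim), broadcast across the sequence dimension
        h = self.attn_norm(x) * t_emb   # "pre" conditioning
        x = x + self.attn(h, h, h, need_weights=False)[0]
        h = self.mlp_norm(x) * t_emb    # "post" conditioning
        return x + self.mlp(h)
```

For CFG, the recipe used here is the usual one: randomly drop the conditioning (audio prompt alone, or text and prompt together) during training, then at inference run a conditional and an unconditional ("null") pass and extrapolate the logits as `logits_cond + (logits_cond - logits_uncond) * cfg_strength`.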
## Embeddings

View File

@@ -258,6 +258,11 @@ class ModelExperimentalSettings:
len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
# to-do: just incorporate this as a task instead
# classifier-free guidance shit
cfg_cond_dropout_p: float = 0.2 # probability to drop out text and audio during training
cfg_text_dropout_p: float = 0.0 # probability to drop out input text during training
cfg_prom_dropout_p: float = 0.3 # probability to drop out input audio prompt during training
layerskip: bool = False # layerskip compatible model (or training for)
#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
layerskip_r: int = 2 # number of layers to factor into early-exit loss calc

View File

@@ -30,7 +30,7 @@ except Exception as e:
pass
try:
from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM
from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaDecoderLayer, LlamaDecoderLayer_Adapted, LlamaForCausalLM
AVAILABLE_ARCHES.append("llama")
except Exception as e:
ERROR_ARCHES["llama"] = e

View File

@@ -13,7 +13,7 @@ from transformers.cache_utils import Cache
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
_logger = logging.getLogger(__name__)
@@ -304,15 +304,125 @@ class LlamaAttention_Adapted(LlamaAttention):
return attn_output, attn_scores, past_key_value
class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
# apply timestep embedding with attention norm
# I don't have a concrete idea of how helpful this is, as:
# * F5-TTS's UNetT implementation doesn't do this
# * F5-TTS's DiT does this, but only for pre-attention normalization
# * MaskGCT does this for both
# * Muse doesn't do this, but instead appends the timestep embedding
def weigh_by_timestep(
self,
hidden_states,
timesteps,
):
if timesteps is None:
return hidden_states
for i, timestep in enumerate( timesteps ):
# no timestep for this batch entry; leave it untouched
if not isinstance( timestep, torch.Tensor ):
continue
hidden_states[i] *= timestep
return hidden_states
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
timesteps: Optional[list] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class LlamaModel_Adapted(LlamaModel):
def __init__(self, *args, **kwargs):
def __init__(self, config, *args, **kwargs):
self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
self.early_exit_r = kwargs.pop("early_exit_r", 2)
super().__init__(*args, **kwargs)
#super().__init__(*args, **kwargs)
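# note: this deliberately skips LlamaModel.__init__ (calling its parent's __init__ directly) so the layer stack below can be built from LlamaDecoderLayer_Adapted instead of the stock LlamaDecoderLayer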
super(LlamaModel, self).__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.layers_n = config.num_hidden_layers
# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[LlamaDecoderLayer_Adapted(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
self.layers_n = len(self.layers)
def dropoff_layer( self, l ):
if not self.training:
return False
@@ -360,6 +470,7 @@ class LlamaModel_Adapted(LlamaModel):
cache_position: Optional[torch.LongTensor] = None,
layer_skip_lambda = None,
timesteps: Optional[list] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -377,8 +488,10 @@ class LlamaModel_Adapted(LlamaModel):
)
use_cache = False
"""
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
"""
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
@@ -430,6 +543,7 @@ class LlamaModel_Adapted(LlamaModel):
use_cache,
cache_position,
position_embeddings,
timesteps,
)
else:
layer_outputs = decoder_layer(
@@ -441,6 +555,7 @@ class LlamaModel_Adapted(LlamaModel):
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
timesteps=timesteps,
)
if not self.dropoff_layer( l ):
@@ -469,6 +584,7 @@ class LlamaModel_Adapted(LlamaModel):
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,

View File

@@ -53,7 +53,9 @@ def _dropout_mask( input, p=None ):
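# when no masking probability is given, a timestep t is drawn uniformly from [0, 1) and mapped through the cosine schedule p = cos(t * pi / 2), so p is near 1 for small t and near 0 as t approaches 1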
t = random.random()
p = math.cos(t * math.pi * 0.5)
return torch.tensor( [ random.random() < p for _ in range( input.shape[0] ) ], dtype=torch.bool, device=input.device )
seq = [ random.random() < p for _ in range( input.shape[0] ) ]
mask = torch.tensor( seq, dtype=torch.bool, device=input.device )
return mask
def clamp(n, lo, hi):
return max(lo, min(n, hi))
@@ -646,7 +648,8 @@ class Base(nn.Module):
use_reentrant=False
))
elif self.arch_type == "llama":
LlamaClass = LlamaModel_Adapted if self.layerskip else LlamaModel
LlamaClass = LlamaModel_Adapted if (self.layerskip or "len" in self.capabilities) else LlamaModel
if n_experts <= 1:
self.model = LlamaClass(LlamaConfig(
vocab_size=n_resp_tokens,
@@ -664,8 +667,16 @@ class Base(nn.Module):
attn_implementation=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
# replace with desired attention
if attention_backend not in HF_ATTENTIONS:
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
# replace with modified Llama
"""
if "len" in self.capabilities:
self.model = ml.replace_attention( self.model, klass=LlamaDecoderLayer_Adapted, target=LlamaDecoderLayer, mode=attention_backend )
"""
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,
@@ -866,6 +877,7 @@ class Base(nn.Module):
state = None,
layer_skip_lambda = None,
timesteps = None,
output_attentions = False,
output_hidden_states = False,
@@ -896,6 +908,9 @@ class Base(nn.Module):
if self.layerskip and layer_skip_lambda is not None:
kwargs["layer_skip_lambda"] = layer_skip_lambda
if "len" in self.capabilities and timesteps is not None:
kwargs["timesteps"] = timesteps
output = self.model(**kwargs)
x = output["last_hidden_state"]
@@ -1036,8 +1051,12 @@ class Base(nn.Module):
p = math.cos(t * math.pi * 0.5)
dropout_mask = _dropout_mask( resps_list[i], p=p )
inputs[i].append( ("timestep", torch.tensor(t, device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
inputs[i].append( ("timestep", torch.tensor([t], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
inputs[i].append( ("dropout_mask", dropout_mask ) )
else:
# in the event it's needed (it makes shit sound worse)
#inputs[i].append( ("timestep", torch.tensor([1.0], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
...
# Audio length prediction task
# Sequence: <text><sep><rvq lvl><prom><sep><len>
@@ -1088,7 +1107,6 @@ class Base(nn.Module):
inputs[i].append( ( "text", text_list[i] ) )
else:
raise Exception(f'Unrecognized task: {task_type}')
return inputs
def inputs_to_embeddings(
@@ -1139,13 +1157,10 @@ class Base(nn.Module):
for name, input in batch_input:
if name == "dropout_mask":
dropout_mask = input
elif name == "timestep":
timestep = input
for name, input in batch_input:
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
embedding = None
# is already an embedding
if name == "task":
# noop
@@ -1162,6 +1177,10 @@ class Base(nn.Module):
embedding = self.langs_emb( input )
elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input
"""
if proms is None:
continue
"""
# to-do: probably insert separators if task requires it?
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
elif name == "tone" and self.tones_emb is not None:
@@ -1179,13 +1198,11 @@ class Base(nn.Module):
# if training NAR-len RVQ level 0
elif "len" in self.capabilities and quant_level == 0 and dropout_mask is not None:
embedding = self.resps_emb(
# if masked use masked token, else original token
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
offset = 0,
quant_level = 0,
)
t_emb = self.time_emb( timestep )
embedding += t_emb
# cheat-y way to handle performing STT across all levels
elif task_type in summed_embeddings_task:
# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
@@ -1227,18 +1244,36 @@ class Base(nn.Module):
continue
embedding[i] = self.dropout_token
elif name == "timestep" and self.time_emb is not None:
embedding = self.time_emb( input )
elif name == "len" and self.len_emb is not None:
embedding = self.len_emb( input )
else:
# should probably raise an exception so things aren't processed silently
continue
batch.append(embedding)
x_list.append( _join( batch, self.sep ) )
return x_list
# get an attribute from a given input list
def get_input(
self,
inputs,
name,
at=None,
):
for batch_index, batch_input in enumerate(inputs):
if at is not None and batch_index != at:
continue
for n, input in batch_input:
if n == name:
return input
return None
# creates position ids from a given input list
# if not unified_position_ids, then each input segment will have its own sequence
def inputs_to_position_ids(
@@ -1262,7 +1297,7 @@ class Base(nn.Module):
return 1
# a mask
if name in ["dropout_mask", "timestep"]:
if name in ["dropout_mask"]:
return 0
# list of tokens
@@ -1342,6 +1377,7 @@ class Base(nn.Module):
elif name == "resp":
# mask found, apply it
if dropout_mask is not None:
# if masked, use original token, else ignore
target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
elif self.interleave:
target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
@@ -1350,6 +1386,8 @@ class Base(nn.Module):
target.append( torch.full_like(input[..., 0], self.ignore_index) )
else:
target.append( input if input.dim() == 1 else input[:, quant_level] )
elif name == "timestep":
target.append( torch.tensor([self.ignore_index], device=input.device) )
elif name in ["text", "quant_level", "lang", "tone", "len"]:
target.append( input )
@@ -1579,7 +1617,15 @@ class Base(nn.Module):
#position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
"""
timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
#timesteps = [ inputs[i][-1] if timestep is not None else None for i, timestep in enumerate(timesteps) ]
timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
"""
timesteps = []
classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
output = self._forward(
inputs=x,
@@ -1589,6 +1635,7 @@ class Base(nn.Module):
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
timesteps=timesteps,
)
logits = output.logits

View File

@@ -128,6 +128,10 @@ class NAR(Base):
token_dropout_error = self.config.experimental.token_dropout_error
# RVQ levels to apply token dropout on
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
# implicitly set it to all levels
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
@@ -150,12 +154,10 @@ class NAR(Base):
# trim resps to only contain all levels below the target level
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
# tensor to cat for RVQ level 0
text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
# empty string for CFG
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
# I hate python's value/reference semantics so much
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
for i, quant_level, text, resps, proms, task in zip(range(batch_size), quant_levels, text_list, resps_list, proms_list, task_list):
# cap quant_level if it exceeds its corresponding resp/prom
if quant_level >= resps.shape[-1]:
quant_levels[i] = resps.shape[-1] - 1
@@ -193,6 +195,24 @@ class NAR(Base):
#resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
...
# apply CFG (should probably only apply to NAR quant level 0)
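# during training, drop the audio prompt alone with probability cfg_prom_dropout_p, or both text and prompt with probability cfg_cond_dropout_p, so the model also learns an unconditional mode for classifier-free guidance to contrast against at inference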
if task not in text_task:
drop_text = False
drop_audio = False
if random.random() < cfg_prom_dropout_p:
drop_audio = True
if random.random() < cfg_cond_dropout_p:
drop_audio = True
drop_text = True
if drop_text:
text_list[i] = text_start_stop_sequence
if drop_audio:
proms_list[i] = None
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
@@ -225,7 +245,7 @@ class NAR(Base):
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
# initial condition
len_list = [ min(l, 75*3) for l in len_list ]
len_list = [ clamp(l, 1, 75*3) for l in len_list ]
metrics = []
mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
@@ -279,6 +299,10 @@ class NAR(Base):
noise_p = math.cos( start_noise * math.pi * 0.5 )
input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )
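# null condition for classifier-free guidance: an "empty" text sequence (just the start/stop tokens) and no audio prompt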
null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
null_prom = None
cfg_strength = 1.0
for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
# anneal temperature
temperature = starting_temperature * (steps_until_x0 / max_steps)
@@ -294,6 +318,7 @@ class NAR(Base):
is_masked = input_ids == self.stop_token
# setup inputs
resps_list = [ input_ids ]
inputs = _super.inputs(
text_list=text_list,
proms_list=proms_list,
@@ -308,11 +333,29 @@ class NAR(Base):
quant_levels=quant_levels,
layer_skip_variables=sampling_layer_skip_variables,
)
if cfg_strength > 0:
null_inputs = _super.inputs(
text_list=[ null_text ],
proms_list=[ null_prom ],
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
time_list=[ timestep ],
quant_levels=quant_levels,
)
null_output = _super.forward(
inputs=null_inputs,
quant_levels=quant_levels,
layer_skip_variables=sampling_layer_skip_variables,
)
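# classifier-free guidance: extrapolate the conditional logits away from the unconditional (null) logits by cfg_strength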
logits = [ logits + ( logits - null_logits ) * cfg_strength for logits, null_logits in zip(output.logits, null_output.logits) ]
else:
logits = output.logits
# sample with sampler settings
sampling_top_p = 0.9
filtered_sampled = _super.sample(
logits=output.logits,
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
@@ -328,7 +371,7 @@ class NAR(Base):
# retrieves unfiltered logits
unfiltered_sampled = _super.sample(
logits=output.logits,
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
temperature=0.0,
@@ -504,7 +547,6 @@ def example_usage():
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
text_list = [
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),