This better work

parent 8b3d1cf70a
commit c6a38693a2
@@ -56,7 +56,13 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
 The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
 * incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation
 
-To-do: fill out this more when it works.
+To-do: fill out this more when it works. Getting this to work is a huge pain.
+* Some masked transformers do not "inject" any timestep information (Text-To-Image Muse, as far as I can tell)
+* Others "expose" it by applying a timestep embedding after pre/post-attention normalization
+  * except F5-TTS only does this for pre-attention normalization in its DiT, but not in its UNetT
+  * MaskGCT does it both pre and post
+  * the test trainer actually degrades the output immensely when doing this
+* I'm sure I've seen a masked transformer without CFG, but most of them seem to use it (and all seem to be poorly documented on how exactly they apply it, at least for my dumb brain); a sketch is given below
 
 ## Embeddings
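As a concrete illustration of the masking paradigm described above, here is a minimal, self-contained sketch of one masked-NAR training step and a CFG-guided sampling loop. It assumes the cosine masking schedule and the `logits + (logits - null_logits) * cfg_strength` blend that the trainer below uses; the `model` callable, `mask_token`, and the greedy unmasking are illustrative stand-ins, not the project's actual implementation.

```python
import math, random
import torch
import torch.nn.functional as F

def cosine_mask(tokens, t, mask_token):
    # p = cos(t * pi / 2): t near 0 masks almost everything, t near 1 masks almost nothing
    p = math.cos(t * math.pi * 0.5)
    mask = torch.tensor([random.random() < p for _ in range(tokens.shape[0])], dtype=torch.bool, device=tokens.device)
    return torch.where(mask, torch.full_like(tokens, mask_token), tokens), mask

def masked_nar_step(model, text, prom, resps, mask_token, ignore_index=-100):
    # training: draw a timestep, mask the target audio tokens, and only score the masked positions
    t = random.random()
    noised, mask = cosine_mask(resps, t, mask_token)
    logits = model(text=text, prom=prom, resps=noised, timestep=t)   # (seq_len, vocab)
    target = torch.where(mask, resps, torch.full_like(resps, ignore_index))
    return F.cross_entropy(logits, target, ignore_index=ignore_index)

def masked_nar_sample(model, text, prom, length, mask_token, steps=25, cfg_strength=1.0):
    # inference: start fully masked and iteratively fill in tokens, guided by CFG
    ids = torch.full((length,), mask_token, dtype=torch.long)
    for t in torch.linspace(0.0, 1.0, steps):
        cond = model(text=text, prom=prom, resps=ids, timestep=t)
        null = model(text=None, prom=None, resps=ids, timestep=t)    # "empty" conditioning
        logits = cond + (cond - null) * cfg_strength
        sampled = logits.argmax(dim=-1)
        ids = torch.where(ids == mask_token, sampled, ids)           # only fill masked slots
        # a fuller sampler would re-mask the lowest-confidence tokens before the next step
    return ids
```

The actual code keys the schedule off `_dropout_mask` and reuses the stop token as the mask token, as the diffs below show.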
@@ -258,6 +258,11 @@ class ModelExperimentalSettings:
     len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
     # to-do: just incorporate this as a task instead
 
+    # classifier-free guidance shit
+    cfg_cond_dropout_p: float = 0.2 # probability to drop out both text and audio prompt during training
+    cfg_text_dropout_p: float = 0.0 # probability to drop out input text during training
+    cfg_prom_dropout_p: float = 0.3 # probability to drop out input audio prompt during training
+
     layerskip: bool = False # layerskip compatible model (or training for)
     #layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
     layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
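The three `cfg_*` probabilities above drive classifier-free guidance training: with some chance the text, the audio prompt, or both are dropped so the model also learns an unconditional path. A rough sketch of how they could be applied per sample, mirroring the NAR trainer further down (`null_text`/`null_prom` here are illustrative placeholders for the "empty" conditioning):

```python
import random

def apply_cfg_dropout(text, prom, cfg_text_dropout_p=0.0, cfg_prom_dropout_p=0.3, cfg_cond_dropout_p=0.2):
    # independently drop the text, the audio prompt, or both
    drop_text = random.random() < cfg_text_dropout_p
    drop_prom = random.random() < cfg_prom_dropout_p
    if random.random() < cfg_cond_dropout_p:
        drop_text = True
        drop_prom = True

    null_text = None  # in practice an "empty" text sequence (e.g. just the start/stop tokens)
    null_prom = None  # in practice simply no acoustic prompt
    return (null_text if drop_text else text), (null_prom if drop_prom else prom)
```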
@@ -30,7 +30,7 @@ except Exception as e:
     pass
 
 try:
-    from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM
+    from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaDecoderLayer, LlamaDecoderLayer_Adapted, LlamaForCausalLM
     AVAILABLE_ARCHES.append("llama")
 except Exception as e:
     ERROR_ARCHES["llama"] = e
@@ -13,7 +13,7 @@ from transformers.cache_utils import Cache
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
 
 _logger = logging.getLogger(__name__)
 
@@ -304,15 +304,125 @@ class LlamaAttention_Adapted(LlamaAttention):
 
         return attn_output, attn_scores, past_key_value
 
+class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
+    # apply timestep embedding with attention norm
+    # I don't have a concrete idea on how helpful this is, as:
+    # * F5-TTS's UNetT implementation doesn't do this
+    # * F5-TTS's DiT does this, but only for pre-attention normalization
+    # * MaskGCT does this for both
+    # * Muse doesn't do this, but instead appends the timestep embedding
+    def weigh_by_timestep(
+        self,
+        hidden_states,
+        timesteps,
+    ):
+        if timesteps is None:
+            return hidden_states
+
+        for i, timestep in enumerate( timesteps ):
+            # invalid
+            if not isinstance( timestep, torch.Tensor ):
+                continue
+            hidden_states[i] *= timestep
+
+        return hidden_states
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+        timesteps: Optional[list] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
+
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
 class LlamaModel_Adapted(LlamaModel):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, config, *args, **kwargs):
         self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
         self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
         self.early_exit_r = kwargs.pop("early_exit_r", 2)
 
-        super().__init__(*args, **kwargs)
+        #super().__init__(*args, **kwargs)
+        super(LlamaModel, self).__init__(config)
+
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.layers_n = config.num_hidden_layers
+
+        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+
+        self.layers = nn.ModuleList(
+            [LlamaDecoderLayer_Adapted(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
 
-        self.layers_n = len(self.layers)
-
     def dropoff_layer( self, l ):
         if not self.training:
             return False
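The `timesteps` that `weigh_by_timestep` consumes are per-sample timestep embeddings produced upstream (the trainer below references a `time_emb` module with an `mlp`). As a point of reference only, a common way to build such an embedding is sinusoidal features followed by a small MLP; this sketch is an assumption about the shape of that module, not the repo's actual `time_emb`:

```python
import math
import torch
from torch import nn

class TimestepEmbedding(nn.Module):
    # sinusoidal features of a scalar timestep in [0, 1], projected by a small MLP
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: shape (1,) or (batch,), values in [0, 1]
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
        angles = t.float().unsqueeze(-1) * freqs
        feats = torch.cat([angles.sin(), angles.cos()], dim=-1)
        return self.mlp(feats)

# each adapted decoder layer then scales its normalized hidden states by this embedding,
# i.e. hidden_states[i] *= timesteps[i], as weigh_by_timestep does above
```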
@@ -360,6 +470,7 @@ class LlamaModel_Adapted(LlamaModel):
         cache_position: Optional[torch.LongTensor] = None,
 
         layer_skip_lambda = None,
+        timesteps: Optional[list] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -377,8 +488,10 @@ class LlamaModel_Adapted(LlamaModel):
             )
             use_cache = False
 
+        """
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
+        """
 
         # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
@@ -430,6 +543,7 @@ class LlamaModel_Adapted(LlamaModel):
                     use_cache,
                     cache_position,
                     position_embeddings,
+                    timesteps,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -441,6 +555,7 @@ class LlamaModel_Adapted(LlamaModel):
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    timesteps=timesteps,
                 )
 
             if not self.dropoff_layer( l ):
@@ -469,6 +584,7 @@ class LlamaModel_Adapted(LlamaModel):
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
@@ -53,7 +53,9 @@ def _dropout_mask( input, p=None ):
         t = random.random()
         p = math.cos(t * math.pi * 0.5)
 
-    return torch.tensor( [ random.random() < p for _ in range( input.shape[0] ) ], dtype=torch.bool, device=input.device )
+    seq = [ random.random() < p for _ in range( input.shape[0] ) ]
+    mask = torch.tensor( seq, dtype=torch.bool, device=input.device )
+    return mask
 
 def clamp(n, lo, hi):
     return max(lo, min(n, hi))
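For intuition, the cosine schedule in `_dropout_mask` masks almost everything at t near 0 and almost nothing at t near 1; a quick numeric check of the expected mask ratio:

```python
import math

# expected fraction of masked tokens at a few timesteps, p = cos(t * pi / 2)
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    p = math.cos(t * math.pi * 0.5)
    print(f"t={t:.2f} -> mask ratio ~{p:.3f}")
# t=0.00 -> mask ratio ~1.000
# t=0.25 -> mask ratio ~0.924
# t=0.50 -> mask ratio ~0.707
# t=0.75 -> mask ratio ~0.383
# t=1.00 -> mask ratio ~0.000
```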
@@ -646,7 +648,8 @@ class Base(nn.Module):
                 use_reentrant=False
             ))
         elif self.arch_type == "llama":
-            LlamaClass = LlamaModel_Adapted if self.layerskip else LlamaModel
+            LlamaClass = LlamaModel_Adapted if (self.layerskip or "len" in self.capabilities) else LlamaModel
+
             if n_experts <= 1:
                 self.model = LlamaClass(LlamaConfig(
                     vocab_size=n_resp_tokens,
@@ -664,8 +667,16 @@ class Base(nn.Module):
                     attn_implementation=hf_attention,
                     #gradient_checkpointing=self.gradient_checkpointing,
                 ))
+
+                # replace with desired attention
                 if attention_backend not in HF_ATTENTIONS:
                     self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
+
+                # replace with modified Llama
+                """
+                if "len" in self.capabilities:
+                    self.model = ml.replace_attention( self.model, klass=LlamaDecoderLayer_Adapted, target=LlamaDecoderLayer, mode=attention_backend )
+                """
             else:
                 self.model = MixtralModel(MixtralConfig(
                     vocab_size =n_resp_tokens,
@@ -866,6 +877,7 @@ class Base(nn.Module):
         state = None,
 
         layer_skip_lambda = None,
+        timesteps = None,
 
         output_attentions = False,
         output_hidden_states = False,
@@ -896,6 +908,9 @@ class Base(nn.Module):
         if self.layerskip and layer_skip_lambda is not None:
             kwargs["layer_skip_lambda"] = layer_skip_lambda
 
+        if "len" in self.capabilities and timesteps is not None:
+            kwargs["timesteps"] = timesteps
+
         output = self.model(**kwargs)
         x = output["last_hidden_state"]
@@ -1036,8 +1051,12 @@ class Base(nn.Module):
                 p = math.cos(t * math.pi * 0.5)
                 dropout_mask = _dropout_mask( resps_list[i], p=p )
 
-                inputs[i].append( ("timestep", torch.tensor(t, device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
+                inputs[i].append( ("timestep", torch.tensor([t], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
                 inputs[i].append( ("dropout_mask", dropout_mask ) )
+            else:
+                # in the event it's needed (it makes shit sound worse)
+                #inputs[i].append( ("timestep", torch.tensor([1.0], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
+                ...
 
         # Audio length prediction task
         # Sequence: <text><sep><rvq lvl><prom><sep><len>
@@ -1088,7 +1107,6 @@ class Base(nn.Module):
                 inputs[i].append( ( "text", text_list[i] ) )
             else:
                 raise Exception(f'Unrecognized task: {task_type}')
 
         return inputs
 
     def inputs_to_embeddings(
@@ -1139,13 +1157,10 @@ class Base(nn.Module):
             for name, input in batch_input:
                 if name == "dropout_mask":
                     dropout_mask = input
-                elif name == "timestep":
-                    timestep = input
-
 
             for name, input in batch_input:
                 # technically can provide a map for input_name => embedding, but some embedding requires additional processing
                 embedding = None
 
                 # is already an embedding
                 if name == "task":
                     # noop
@@ -1162,6 +1177,10 @@ class Base(nn.Module):
                     embedding = self.langs_emb( input )
                 elif name == "prom":
                     proms = [ input ] if isinstance(input, torch.Tensor) else input
+                    """
+                    if proms is None:
+                        continue
+                    """
                     # to-do: probably insert separators if task requires it?
                     embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
                 elif name == "tone" and self.tones_emb is not None:
@@ -1179,13 +1198,11 @@ class Base(nn.Module):
                 # if training NAR-len RVQ level 0
                 elif "len" in self.capabilities and quant_level == 0 and dropout_mask is not None:
                     embedding = self.resps_emb(
+                        # if masked use masked token, else original token
                         torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
                         offset = 0,
                         quant_level = 0,
                     )
-
-                    t_emb = self.time_emb( timestep )
-                    embedding += t_emb
 
                 # cheat-y way to handle performing STT across all levels
                 elif task_type in summed_embeddings_task:
                     # we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
@@ -1227,18 +1244,36 @@ class Base(nn.Module):
                         continue
 
                     embedding[i] = self.dropout_token
+                elif name == "timestep" and self.time_emb is not None:
+                    embedding = self.time_emb( input )
                 elif name == "len" and self.len_emb is not None:
                     embedding = self.len_emb( input )
                 else:
                     # should probably raise an exception so things aren't processed silently
                     continue
 
                 batch.append(embedding)
 
            x_list.append( _join( batch, self.sep ) )
 
         return x_list
 
+    # get an attribute from a given input list
+    def get_input(
+        self,
+        inputs,
+        name,
+        at=None,
+    ):
+        for batch_index, batch_input in enumerate(inputs):
+            # only look at the requested batch index, if one was given
+            if at is not None and batch_index != at:
+                continue
+
+            for n, input in batch_input:
+                if n == name:
+                    return input
+        return None
+
     # creates position ids from a given input list
     # if not unified_position_ids, then each input segment will have its own sequence
     def inputs_to_position_ids(
@@ -1262,7 +1297,7 @@ class Base(nn.Module):
                 return 1
 
             # a mask
-            if name in ["dropout_mask", "timestep"]:
+            if name in ["dropout_mask"]:
                 return 0
 
             # list of tokens
@@ -1342,6 +1377,7 @@ class Base(nn.Module):
                 elif name == "resp":
                     # mask found, apply it
                     if dropout_mask is not None:
+                        # if mask use original token, else ignore
                         target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
                     elif self.interleave:
                         target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
@@ -1350,6 +1386,8 @@ class Base(nn.Module):
                         target.append( torch.full_like(input[..., 0], self.ignore_index) )
                     else:
                         target.append( input if input.dim() == 1 else input[:, quant_level] )
+                elif name == "timestep":
+                    target.append( torch.tensor([self.ignore_index], device=input.device) )
                 elif name in ["text", "quant_level", "lang", "tone", "len"]:
                     target.append( input )
@@ -1579,7 +1617,15 @@ class Base(nn.Module):
         #position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
         position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 
-        classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
+        tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
+        """
+        timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
+        #timesteps = [ inputs[i][-1] if timestep is not None else None for i, timestep in enumerate(timesteps) ]
+        timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
+        """
+        timesteps = []
+
+        classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
 
         output = self._forward(
             inputs=x,
@@ -1589,6 +1635,7 @@ class Base(nn.Module):
             output_attentions = output_attentions,
             output_hidden_states = output_hidden_states,
             layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
+            timesteps=timesteps,
         )
 
         logits = output.logits
@@ -128,6 +128,10 @@ class NAR(Base):
         token_dropout_error = self.config.experimental.token_dropout_error
         # RVQ levels to apply token dropout on
         token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
+        # CFG
+        cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
+        cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
+        cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
         # implicitly set it to all levels
         if not token_dropout_rvq_levels:
             token_dropout_rvq_levels = [0, self.resp_levels - 1]
@@ -150,12 +154,10 @@ class NAR(Base):
 
         # trim resps to only contain all levels below the target level
         resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
-        # tensor to cat for RVQ level 0
-        text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
-        audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
+        # empty string for CFG
+        text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
 
         # I hate python's value/reference semantics so much
-        for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
+        for i, quant_level, text, resps, proms, task in zip(range(batch_size), quant_levels, text_list, resps_list, proms_list, task_list):
             # cap quant_level if it exceeds its corresponding resp/prom
             if quant_level >= resps.shape[-1]:
                 quant_levels[i] = resps.shape[-1] - 1
@@ -193,6 +195,24 @@ class NAR(Base):
                 #resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
                 ...
 
+            # apply CFG (should probably only apply to NAR quant level 0)
+            if task not in text_task:
+                drop_text = False
+                drop_audio = False
+
+                if random.random() < cfg_prom_dropout_p:
+                    drop_audio = True
+
+                if random.random() < cfg_cond_dropout_p:
+                    drop_audio = True
+                    drop_text = True
+
+                if drop_text:
+                    text_list[i] = text_start_stop_sequence
+
+                if drop_audio:
+                    proms_list[i] = None
+
         inputs = self.inputs(
             text_list=text_list,
             proms_list=proms_list,
@@ -225,7 +245,7 @@ class NAR(Base):
             sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
 
         # initial condition
-        len_list = [ min(l, 75*3) for l in len_list ]
+        len_list = [ clamp(l, 1, 75*3) for l in len_list ] # clamp each length into [1, 75*3]
         metrics = []
 
         mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
@@ -279,6 +299,10 @@ class NAR(Base):
         noise_p = math.cos( start_noise * math.pi * 0.5 )
         input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )
 
+        null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
+        null_prom = None
+        cfg_strength = 1.0
+
         for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
             # anneal temperature
             temperature = starting_temperature * (steps_until_x0 / max_steps)
@@ -294,6 +318,7 @@ class NAR(Base):
             is_masked = input_ids == self.stop_token
             # setup inputs
            resps_list = [ input_ids ]
+
            inputs = _super.inputs(
                text_list=text_list,
                proms_list=proms_list,
@@ -308,11 +333,29 @@ class NAR(Base):
                 quant_levels=quant_levels,
                 layer_skip_variables=sampling_layer_skip_variables,
             )
+            if cfg_strength > 0:
+                null_inputs = _super.inputs(
+                    text_list=[ null_text ],
+                    proms_list=[ null_prom ],
+                    resps_list=resps_list,
+                    lang_list=lang_list,
+                    tone_list=tone_list,
+                    time_list=[ timestep ],
+                    quant_levels=quant_levels,
+                )
+                null_output = _super.forward(
+                    inputs=null_inputs,
+                    quant_levels=quant_levels,
+                    layer_skip_variables=sampling_layer_skip_variables,
+                )
+                logits = [ logits + ( logits - null_logits ) * cfg_strength for logits, null_logits in zip(output.logits, null_output.logits) ]
+            else:
+                logits = output.logits
 
             # sample with sampler settings
             sampling_top_p = 0.9
             filtered_sampled = _super.sample(
-                logits=output.logits,
+                logits=logits,
                 prev_list=prev_list,
                 quant_levels=quant_levels,
@@ -328,7 +371,7 @@ class NAR(Base):
 
             # retrieves unfiltered logits
             unfiltered_sampled = _super.sample(
-                logits=output.logits,
+                logits=logits,
                 prev_list=prev_list,
                 quant_levels=quant_levels,
                 temperature=0.0,
@@ -504,7 +547,6 @@ def example_usage():
 
     qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
-
 
     text_list = [
         tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
         #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),