This better work
parent 8b3d1cf70a
commit c6a38693a2
@@ -56,7 +56,13 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
 The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
 * incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation
 
-To-do: fill out this more when it works.
+To-do: fill out this more when it works. Getting this to work is a huge pain.
+* Some masked transformers do not "inject" any timestep information at all (Text-To-Image Muse, as far as I can tell)
+* Others "expose" it by applying a timestep embedding after pre/post-attention normalization
+  * except F5-TTS only does this pre-attention for its DiT, but not for its UNetT
+  * MaskGCT does it both pre and post
+  * the test trainer actually degrades the output immensely when doing this
+* I'm sure I've seen a masked transformer without CFG, but most of them seem to use it (and all of them seem to be poorly documented on how exactly they apply it); a rough sketch of the whole demask-with-CFG loop follows below
 
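For intuition, the loop being built toward here is a MaskGIT-style iterative demasking pass: start fully masked, score every masked position each step, commit the most confident predictions, and keep the rest masked according to the cosine schedule, optionally sharpening with CFG against a null-conditioned pass. The sketch below is only illustrative and is not the actual sampler in this repo; `score_fn`, `null_score_fn`, and `mask_id` are hypothetical stand-ins.

```python
import math
import torch

def demask_sample(score_fn, null_score_fn, seq_len, mask_id, steps=25, cfg_strength=1.0):
    # start from a fully-masked sequence
    tokens = torch.full((seq_len,), mask_id, dtype=torch.long)
    for t in torch.linspace(0.0, 1.0, steps):
        # cosine schedule: fraction of the sequence that should still be masked at this timestep
        p_masked = math.cos(t.item() * math.pi * 0.5)
        logits = score_fn(tokens, t)  # assumed shape: (seq_len, vocab)
        if cfg_strength > 0:
            # classifier-free guidance: push away from the null-conditioned prediction
            logits = logits + (logits - null_score_fn(tokens, t)) * cfg_strength
        confidence, candidates = logits.softmax(dim=-1).max(dim=-1)
        is_masked = tokens == mask_id
        # only still-masked positions compete to be committed this step
        confidence[~is_masked] = float("-inf")
        target_unmasked = max(1, int((1.0 - p_masked) * seq_len))
        n_new = max(0, target_unmasked - int((~is_masked).sum()))
        if n_new:
            commit = confidence.topk(n_new).indices
            tokens[commit] = candidates[commit]
    return tokens
```

The cosine schedule and the `logits + (logits - null_logits) * cfg_strength` combination mirror expressions that appear in the code changes below; the confidence-based committing is just one common way of choosing which positions to unmask.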
 ## Embeddings
 
@@ -258,6 +258,11 @@ class ModelExperimentalSettings:
     len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
+    # to-do: just incorporate this as a task instead
 
+    # classifier-free guidance shit
+    cfg_cond_dropout_p: float = 0.2 # probability to drop out both text and audio prompt during training
+    cfg_text_dropout_p: float = 0.0 # probability to drop out input text during training
+    cfg_prom_dropout_p: float = 0.3 # probability to drop out input audio prompt during training
 
     layerskip: bool = False # layerskip compatible model (or training for)
     #layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
     layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
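For context on how dropout probabilities like these are usually consumed: each training sample independently decides whether to drop its text and/or audio-prompt conditioning, so the model also learns a (partially) unconditional distribution that CFG can lean on at inference. A minimal sketch assuming the three probabilities above (the function name and return convention are made up for illustration; the real logic lives in the NAR training path further down):

```python
import random

def choose_condition_drops(cfg_cond_dropout_p=0.2, cfg_text_dropout_p=0.0, cfg_prom_dropout_p=0.3):
    # independently decide which conditioning signals this sample loses
    drop_text = random.random() < cfg_text_dropout_p
    drop_prom = random.random() < cfg_prom_dropout_p
    if random.random() < cfg_cond_dropout_p:
        # joint dropout: remove both the text and the audio prompt
        drop_text = True
        drop_prom = True
    return drop_text, drop_prom
```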
@@ -30,7 +30,7 @@ except Exception as e:
     pass
 
 try:
-    from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM
+    from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaDecoderLayer, LlamaDecoderLayer_Adapted, LlamaForCausalLM
     AVAILABLE_ARCHES.append("llama")
 except Exception as e:
     ERROR_ARCHES["llama"] = e
@@ -13,7 +13,7 @@ from transformers.cache_utils import Cache
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
 
 _logger = logging.getLogger(__name__)
 
@@ -304,15 +304,125 @@ class LlamaAttention_Adapted(LlamaAttention):
 
         return attn_output, attn_scores, past_key_value
 
+class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
+    # apply timestep embedding with attention norm
+    # I don't have a concrete idea on how helpful this is, as:
+    # * F5-TTS's UNetT implementation doesn't do this
+    # * F5-TTS's DiT does this, but only for pre-attention normalization
+    # * MaskGCT does this for both
+    # * Muse doesn't do this, but instead appends the timestep embedding
+    def weigh_by_timestep(
+        self,
+        hidden_states,
+        timesteps,
+    ):
+        if timesteps is None:
+            return hidden_states
+
+        for i, timestep in enumerate( timesteps ):
+            # invalid
+            if not isinstance( timestep, torch.Tensor ):
+                continue
+            hidden_states[i] *= timestep
+
+        return hidden_states
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        timesteps: Optional[list] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.weigh_by_timestep( hidden_states, timesteps )
+
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
 class LlamaModel_Adapted(LlamaModel):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, config, *args, **kwargs):
         self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
         self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
         self.early_exit_r = kwargs.pop("early_exit_r", 2)
 
-        super().__init__(*args, **kwargs)
+        #super().__init__(*args, **kwargs)
+        super(LlamaModel, self).__init__(config)
+
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.layers_n = config.num_hidden_layers
+
+        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+
+        self.layers = nn.ModuleList(
+            [LlamaDecoderLayer_Adapted(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
 
-        self.layers_n = len(self.layers)
     def dropoff_layer( self, l ):
         if not self.training:
             return False
@@ -360,6 +470,7 @@ class LlamaModel_Adapted(LlamaModel):
         cache_position: Optional[torch.LongTensor] = None,
 
         layer_skip_lambda = None,
+        timesteps: Optional[list] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -377,8 +488,10 @@ class LlamaModel_Adapted(LlamaModel):
             )
             use_cache = False
 
+        """
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
+        """
 
         # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
@@ -430,6 +543,7 @@ class LlamaModel_Adapted(LlamaModel):
                     use_cache,
                     cache_position,
                     position_embeddings,
+                    timesteps,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -441,6 +555,7 @@ class LlamaModel_Adapted(LlamaModel):
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    timesteps=timesteps,
                 )
 
             if not self.dropoff_layer( l ):
@@ -469,6 +584,7 @@ class LlamaModel_Adapted(LlamaModel):
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
@@ -53,7 +53,8 @@ def _dropout_mask( input, p=None ):
         t = random.random()
         p = math.cos(t * math.pi * 0.5)
 
-    return torch.tensor( [ random.random() < p for _ in range( input.shape[0] ) ], dtype=torch.bool, device=input.device )
+    seq = [ random.random() < p for _ in range( input.shape[0] ) ]
+    mask = torch.tensor( seq, dtype=torch.bool, device=input.device )
+    return mask
 
 def clamp(n, lo, hi):
     return max(lo, min(n, hi))
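To make the cosine schedule in `_dropout_mask` concrete: with `p = cos(t * pi / 2)`, a timestep near 0 masks nearly every position and a timestep near 1 masks almost none. A quick check of the expected masking density (illustrative only):

```python
import math

# expected fraction of positions masked at a few timesteps under p = cos(t * pi / 2)
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    p = math.cos(t * math.pi * 0.5)
    print(f"t={t:.2f} -> ~{p * 100:.0f}% masked")
# t=0.00 -> ~100% masked
# t=0.50 -> ~71% masked
# t=1.00 -> ~0% masked
```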
@@ -646,7 +648,8 @@ class Base(nn.Module):
                 use_reentrant=False
             ))
         elif self.arch_type == "llama":
-            LlamaClass = LlamaModel_Adapted if self.layerskip else LlamaModel
+            LlamaClass = LlamaModel_Adapted if (self.layerskip or "len" in self.capabilities) else LlamaModel
+
             if n_experts <= 1:
                 self.model = LlamaClass(LlamaConfig(
                     vocab_size=n_resp_tokens,
@@ -664,8 +667,16 @@ class Base(nn.Module):
                     attn_implementation=hf_attention,
                     #gradient_checkpointing=self.gradient_checkpointing,
                 ))
 
                 # replace with desired attention
                 if attention_backend not in HF_ATTENTIONS:
                     self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
+
+                # replace with modified Llama
+                """
+                if "len" in self.capabilities:
+                    self.model = ml.replace_attention( self.model, klass=LlamaDecoderLayer_Adapted, target=LlamaDecoderLayer, mode=attention_backend )
+                """
             else:
                 self.model = MixtralModel(MixtralConfig(
                     vocab_size =n_resp_tokens,
@@ -866,6 +877,7 @@ class Base(nn.Module):
         state = None,
 
         layer_skip_lambda = None,
+        timesteps = None,
 
         output_attentions = False,
         output_hidden_states = False,
@@ -896,6 +908,9 @@ class Base(nn.Module):
         if self.layerskip and layer_skip_lambda is not None:
             kwargs["layer_skip_lambda"] = layer_skip_lambda
 
+        if "len" in self.capabilities and timesteps is not None:
+            kwargs["timesteps"] = timesteps
+
         output = self.model(**kwargs)
         x = output["last_hidden_state"]
 
@@ -1036,8 +1051,12 @@ class Base(nn.Module):
                 p = math.cos(t * math.pi * 0.5)
                 dropout_mask = _dropout_mask( resps_list[i], p=p )
 
-                inputs[i].append( ("timestep", torch.tensor(t, device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
+                inputs[i].append( ("timestep", torch.tensor([t], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
                 inputs[i].append( ("dropout_mask", dropout_mask ) )
+            else:
+                # in the event it's needed (it makes shit sound worse)
+                #inputs[i].append( ("timestep", torch.tensor([1.0], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
+                ...
 
         # Audio length prediction task
         # Sequence: <text><sep><rvq lvl><prom><sep><len>
@@ -1088,7 +1107,6 @@ class Base(nn.Module):
                 inputs[i].append( ( "text", text_list[i] ) )
             else:
                 raise Exception(f'Unrecognized task: {task_type}')
 
         return inputs
 
     def inputs_to_embeddings(
@@ -1139,13 +1157,10 @@ class Base(nn.Module):
             for name, input in batch_input:
                 if name == "dropout_mask":
                     dropout_mask = input
                 elif name == "timestep":
                     timestep = input
 
             for name, input in batch_input:
                 # technically can provide a map for input_name => embedding, but some embedding requires additional processing
                 embedding = None
 
                 # is already an embedding
                 if name == "task":
                     # noop
@@ -1162,6 +1177,10 @@ class Base(nn.Module):
                     embedding = self.langs_emb( input )
                 elif name == "prom":
                     proms = [ input ] if isinstance(input, torch.Tensor) else input
+                    """
+                    if proms is None:
+                        continue
+                    """
                     # to-do: probably insert separators if task requires it?
                     embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
                 elif name == "tone" and self.tones_emb is not None:
@@ -1179,13 +1198,11 @@ class Base(nn.Module):
                 # if training NAR-len RVQ level 0
                 elif "len" in self.capabilities and quant_level == 0 and dropout_mask is not None:
                     embedding = self.resps_emb(
                         # if masked use masked token, else original token
                         torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
                         offset = 0,
                         quant_level = 0,
                     )
 
-                    t_emb = self.time_emb( timestep )
-                    embedding += t_emb
                 # cheat-y way to handle performing STT across all levels
                 elif task_type in summed_embeddings_task:
                     # we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
@@ -1227,18 +1244,36 @@ class Base(nn.Module):
                         continue
 
                     embedding[i] = self.dropout_token
 
+                elif name == "timestep" and self.time_emb is not None:
+                    embedding = self.time_emb( input )
                 elif name == "len" and self.len_emb is not None:
                     embedding = self.len_emb( input )
                 else:
                     # should probably raise an exception so things aren't processed silently
                     continue
 
                 batch.append(embedding)
 
             x_list.append( _join( batch, self.sep ) )
 
         return x_list
 
+    # get an attribute from a given input list
+    def get_input(
+        self,
+        inputs,
+        name,
+        at=None,
+    ):
+        for batch_index, batch_input in enumerate(inputs):
+            if at is not None and batch_index != at:
+                continue
+
+            for n, input in batch_input:
+                if n == name:
+                    return input
+        return None
+
     # creates position ids from a given input list
     # if not unified_position_ids, then each input segment will have its own sequence
     def inputs_to_position_ids(
@@ -1262,7 +1297,7 @@ class Base(nn.Module):
                 return 1
 
             # a mask
-            if name in ["dropout_mask", "timestep"]:
+            if name in ["dropout_mask"]:
                 return 0
 
             # list of tokens
@@ -1342,6 +1377,7 @@ class Base(nn.Module):
                 elif name == "resp":
                     # mask found, apply it
                     if dropout_mask is not None:
                         # if mask use original token, else ignore
                         target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
                     elif self.interleave:
                         target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
@@ -1350,6 +1386,8 @@ class Base(nn.Module):
                         target.append( torch.full_like(input[..., 0], self.ignore_index) )
                     else:
                         target.append( input if input.dim() == 1 else input[:, quant_level] )
+                elif name == "timestep":
+                    target.append( torch.tensor([self.ignore_index], device=input.device) )
                 elif name in ["text", "quant_level", "lang", "tone", "len"]:
                     target.append( input )
 
@@ -1579,7 +1617,15 @@ class Base(nn.Module):
         #position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
         position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 
-        classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
+        tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
+        """
+        timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
+        #timesteps = [ inputs[i][-1] if timestep is not None else None for i, timestep in enumerate(timesteps) ]
+        timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
+        """
+        timesteps = []
+
+        classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
 
         output = self._forward(
             inputs=x,
@@ -1589,6 +1635,7 @@ class Base(nn.Module):
             output_attentions = output_attentions,
             output_hidden_states = output_hidden_states,
             layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
+            timesteps=timesteps,
         )
 
         logits = output.logits
@@ -128,6 +128,10 @@ class NAR(Base):
         token_dropout_error = self.config.experimental.token_dropout_error
         # RVQ levels to apply token dropout on
         token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
+        # CFG
+        cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
+        cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
+        cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
         # implicitly set it to all levels
         if not token_dropout_rvq_levels:
             token_dropout_rvq_levels = [0, self.resp_levels - 1]
@@ -150,12 +154,10 @@ class NAR(Base):
 
         # trim resps to only contain all levels below the target level
         resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
 
-        # tensor to cat for RVQ level 0
-        text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
-        audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
+        # empty string for CFG
+        text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
         # I hate python's value/reference semantics so much
-        for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
+        for i, quant_level, text, resps, proms, task in zip(range(batch_size), quant_levels, text_list, resps_list, proms_list, task_list):
             # cap quant_level if it exceeds its corresponding resp/prom
             if quant_level >= resps.shape[-1]:
                 quant_levels[i] = resps.shape[-1] - 1
@@ -193,6 +195,24 @@ class NAR(Base):
                 #resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
                 ...
 
+            # apply CFG (should probably only apply to NAR quant level 0)
+            if task not in text_task:
+                drop_text = False
+                drop_audio = False
+
+                if random.random() < cfg_prom_dropout_p:
+                    drop_audio = True
+
+                if random.random() < cfg_cond_dropout_p:
+                    drop_audio = True
+                    drop_text = True
+
+                if drop_text:
+                    text_list[i] = text_start_stop_sequence
+
+                if drop_audio:
+                    proms_list[i] = None
+
         inputs = self.inputs(
             text_list=text_list,
             proms_list=proms_list,
@@ -225,7 +245,7 @@ class NAR(Base):
             sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
 
         # initial condition
-        len_list = [ min(l, 75*3) for l in len_list ]
+        len_list = [ clamp(l, 1, 75*3) for l in len_list ]
         metrics = []
 
         mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
@@ -279,6 +299,10 @@ class NAR(Base):
         noise_p = math.cos( start_noise * math.pi * 0.5 )
         input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )
 
+        null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
+        null_prom = None
+        cfg_strength = 1.0
+
         for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
             # anneal temperature
             temperature = starting_temperature * (steps_until_x0 / max_steps)
@@ -294,6 +318,7 @@ class NAR(Base):
             is_masked = input_ids == self.stop_token
             # setup inputs
             resps_list = [ input_ids ]
 
             inputs = _super.inputs(
                 text_list=text_list,
                 proms_list=proms_list,
@@ -308,11 +333,29 @@ class NAR(Base):
                 quant_levels=quant_levels,
                 layer_skip_variables=sampling_layer_skip_variables,
             )
+            if cfg_strength > 0:
+                null_inputs = _super.inputs(
+                    text_list=[ null_text ],
+                    proms_list=[ null_prom ],
+                    resps_list=resps_list,
+                    lang_list=lang_list,
+                    tone_list=tone_list,
+                    time_list=[ timestep ],
+                    quant_levels=quant_levels,
+                )
+                null_output = _super.forward(
+                    inputs=null_inputs,
+                    quant_levels=quant_levels,
+                    layer_skip_variables=sampling_layer_skip_variables,
+                )
+                logits = [ logits + ( logits - null_logits ) * cfg_strength for logits, null_logits in zip(output.logits, null_output.logits) ]
+            else:
+                logits = output.logits
 
             # sample with sampler settings
sampling_top_p = 0.9
             filtered_sampled = _super.sample(
-                logits=output.logits,
+                logits=logits,
                 prev_list=prev_list,
                 quant_levels=quant_levels,
 
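For reference, the guidance combination used above amplifies whatever the conditioning changed: `logits_cfg = logits_cond + (logits_cond - logits_null) * cfg_strength`. A tiny numeric sanity check with made-up values:

```python
cfg_strength = 1.0
logit_cond, logit_null = 2.0, 0.5  # conditional vs. null-conditioned logit for one token
logit_cfg = logit_cond + (logit_cond - logit_null) * cfg_strength
assert logit_cfg == 3.5  # with strength 1.0, the gap over the null pass doubles (1.5 -> 3.0)
```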
@@ -328,7 +371,7 @@ class NAR(Base):
 
             # retrieves unfiltered logits
             unfiltered_sampled = _super.sample(
-                logits=output.logits,
+                logits=logits,
                 prev_list=prev_list,
                 quant_levels=quant_levels,
                 temperature=0.0,
@@ -504,7 +547,6 @@ def example_usage():
 
     qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
-
     text_list = [
         tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
         #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),