set option to set training masking ratio (I don't think for tts a fixed masking ratio is beneficial since the magic of the AR+NAR is being able to still reference the prior sequence of tokens for predicting things)

This commit is contained in:
mrq 2024-11-17 17:04:07 -06:00
parent 88d840218d
commit 069b27570f
4 changed files with 21 additions and 26 deletions

View File

@ -261,7 +261,7 @@ class ModelExperimentalSettings:
masking_train_p: float = 0.0 # odds of training with masking masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
masking_ratio_fixed: bool = True # this sets the masking ratio to a fixed 80% masking_ratio: str | float = 0.0 # sets a masking ratio, "random" will randomly pick
ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
# classifier-free guidance shit # classifier-free guidance shit
@ -981,6 +981,9 @@ class Config(BaseConfig):
if "p_len_train" in model["experimental"]: if "p_len_train" in model["experimental"]:
del model["experimental"]["p_len_train"] del model["experimental"]["p_len_train"]
if "masking_ratio_fixed" in model["experimental"]:
del model["experimental"]["masking_ratio_fixed"]
self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ] self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ] self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]

View File

@ -55,16 +55,10 @@ def main():
parser.add_argument("--preamble", type=str, default=None) parser.add_argument("--preamble", type=str, default=None)
parser.add_argument("--output-filename", type=str, default="index.html") parser.add_argument("--output-filename", type=str, default="index.html")
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--language", type=str, default="en") parser.add_argument("--language", type=str, default="en")
parser.add_argument("--task", type=str, default="tts") parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--out-path", type=Path, default=None) parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--model", type=Path, default=None)
parser.add_argument("--lora", type=Path, default=None)
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second) parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
parser.add_argument("--max-steps", type=int, default=25) parser.add_argument("--max-steps", type=int, default=25)
parser.add_argument("--max-levels", type=int, default=7) parser.add_argument("--max-levels", type=int, default=7)
@ -362,7 +356,6 @@ def main():
text=text, text=text,
references=[prompt], references=[prompt],
language=language, language=language,
input_prompt_length=args.input_prompt_length,
seed=seed, seed=seed,
tqdm=False, tqdm=False,
**sampling_kwargs, **sampling_kwargs,

View File

@ -70,6 +70,7 @@ class AR_NAR(Base):
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0 cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
# rate to train RVQ level AR-ly or NAR-ly # rate to train RVQ level AR-ly or NAR-ly
masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5 masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else "random"
# force set mask training # force set mask training
if "len" not in self.capabilities: if "len" not in self.capabilities:
masking_train_p = 0.0 masking_train_p = 0.0
@ -108,6 +109,10 @@ class AR_NAR(Base):
#p = math.acos(r) / (math.pi * 0.5) #p = math.acos(r) / (math.pi * 0.5)
#timesteps[i] = 1.0 - clamp(p, 0.0, 1.0) #timesteps[i] = 1.0 - clamp(p, 0.0, 1.0)
timesteps[i] = random.random() timesteps[i] = random.random()
# instead make it between [0.2, 0.8]
if masking_ratio == "rand":
timesteps[i] = (timesteps[i] * 0.6) + 0.2
# trim resps to only contain all levels below the target level # trim resps to only contain all levels below the target level
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)] resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]

View File

@ -436,7 +436,7 @@ class Base(nn.Module):
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
interleave = self.config.experimental.interleave if self.config is not None else False interleave = self.config.experimental.interleave if self.config is not None else False
masking_ratio_fixed = self.config.experimental.masking_ratio_fixed if self.config is not None else False masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
layerskip = self.config.experimental.layerskip if self.config is not None else False layerskip = self.config.experimental.layerskip if self.config is not None else False
@ -481,7 +481,7 @@ class Base(nn.Module):
self.interleave = interleave self.interleave = interleave
self.layerskip = layerskip self.layerskip = layerskip
self.inject_timestep_embedding = False # results in bad output self.inject_timestep_embedding = False # results in bad output
self.masking_ratio_fixed = masking_ratio_fixed self.masking_ratio = masking_ratio
self.ignore_inputs_for_loss = ignore_inputs_for_loss self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.text_emb = Embedding(n_text_tokens, d_model) self.text_emb = Embedding(n_text_tokens, d_model)
@ -537,7 +537,7 @@ class Base(nn.Module):
# experimental NAR-only mode # experimental NAR-only mode
self.len_emb = Embedding(11, d_model) self.len_emb = Embedding(11, d_model)
self.time_emb = TimeEmbedding(d_model) # if not masking_ratio_fixed else None self.time_emb = TimeEmbedding(d_model) # if not masking_ratio else None
if attention_backend == "auto": if attention_backend == "auto":
attention_backend = "sdpa" attention_backend = "sdpa"
@ -840,7 +840,6 @@ class Base(nn.Module):
state = None, state = None,
layer_skip_lambda = None, layer_skip_lambda = None,
timesteps = None,
output_attentions = False, output_attentions = False,
output_hidden_states = False, output_hidden_states = False,
@ -871,9 +870,6 @@ class Base(nn.Module):
if self.layerskip and layer_skip_lambda is not None: if self.layerskip and layer_skip_lambda is not None:
kwargs["layer_skip_lambda"] = layer_skip_lambda kwargs["layer_skip_lambda"] = layer_skip_lambda
if "len" in self.capabilities and timesteps is not None:
kwargs["timesteps"] = timesteps
output = self.model(**kwargs) output = self.model(**kwargs)
x = output["last_hidden_state"] x = output["last_hidden_state"]
@ -1012,13 +1008,18 @@ class Base(nn.Module):
if timestep is not None: if timestep is not None:
# force set to use this classifier level # force set to use this classifier level
classifier_level = "NAR:0:0" classifier_level = "NAR:0:0"
# a paper said to use a fixed masking ratio for training
p = 0.8
# store timestep information # store timestep information
if not self.masking_ratio_fixed: if self.masking_ratio in ["random", "rand"]:
# cosine scheduled timestep => masking ratio # cosine scheduled timestep => masking ratio
p = math.cos(timestep * math.pi * 0.5) p = math.cos(timestep * math.pi * 0.5)
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) ) # I don't think is is necessary as the timestep is encoded in the sequence by the number of masked tokens, probably.
if self.inject_timestep_embedding:
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
else:
# a paper said to use a fixed masking ratio of 0.8 for training
# ...but I want to make it user adjustable
p = self.masking_ratio
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided) # store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
if self.training: if self.training:
dropout_mask = _dropout_mask( resps_list[i], p ) dropout_mask = _dropout_mask( resps_list[i], p )
@ -1597,12 +1598,6 @@ class Base(nn.Module):
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" ) classifier_levels = self.get_input( inputs, name="classifier_level" )
if self.inject_timestep_embedding:
timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
else:
timesteps = []
output = self._forward( output = self._forward(
inputs=x, inputs=x,
mask=mask, mask=mask,
@ -1611,7 +1606,6 @@ class Base(nn.Module):
output_attentions = output_attentions, output_attentions = output_attentions,
output_hidden_states = output_hidden_states, output_hidden_states = output_hidden_states,
layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None, layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
timesteps=timesteps,
) )
logits = output.logits logits = output.logits