add an option to set the training masking ratio (I don't think a fixed masking ratio is beneficial for TTS, since the magic of the AR+NAR is being able to still reference the prior sequence of tokens when predicting things)
parent 88d840218d
commit 069b27570f
@@ -261,7 +261,7 @@ class ModelExperimentalSettings:
 	masking_train_p: float = 0.0 # odds of training with masking
 	masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
 
-	masking_ratio_fixed: bool = True # this sets the masking ratio to a fixed 80%
+	masking_ratio: str | float = 0.0 # sets a masking ratio, "random" will randomly pick
 	ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
 
 	# classifier-free guidance shit
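
For orientation, here is a minimal standalone sketch (not part of this commit; the helper name resolve_masking_ratio is invented) of the semantics the new field is meant to carry: a plain float pins the masking ratio outright, while "random"/"rand" derives it per sample from a timestep, which the old masking_ratio_fixed boolean could not express.

import math
import random

def resolve_masking_ratio(masking_ratio, timestep=None):
	# hypothetical helper: a float fixes the ratio (0.8 reproduces the old
	# masking_ratio_fixed=True behaviour), while "random"/"rand" maps a
	# per-sample timestep onto a ratio with a cosine schedule
	if masking_ratio in ("random", "rand"):
		if timestep is None:
			timestep = random.random()
		return math.cos(timestep * math.pi * 0.5)
	return float(masking_ratio)

print(resolve_masking_ratio(0.8))            # always 0.8
print(resolve_masking_ratio("random", 0.25)) # ~0.92, early timesteps mask heavily
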
@@ -981,6 +981,9 @@ class Config(BaseConfig):
 			if "p_len_train" in model["experimental"]:
 				del model["experimental"]["p_len_train"]
 
+			if "masking_ratio_fixed" in model["experimental"]:
+				del model["experimental"]["masking_ratio_fixed"]
+
 		self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
 		self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
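
Aside: the hunk above follows the existing pattern of scrubbing deprecated keys from the raw dict before it is splatted into the dataclass. A minimal sketch of that pattern under assumed names (ExperimentalSettings and DEPRECATED_KEYS are illustrative only, not the repo's identifiers):

from dataclasses import dataclass

@dataclass
class ExperimentalSettings:
	masking_train_p: float = 0.0
	masking_ratio: str | float = 0.0

DEPRECATED_KEYS = ("p_len_train", "masking_ratio_fixed")

def load_experimental(raw: dict) -> ExperimentalSettings:
	# drop keys the dataclass no longer declares, otherwise **raw raises TypeError
	for key in DEPRECATED_KEYS:
		raw.pop(key, None)
	return ExperimentalSettings(**raw)

print(load_experimental({"masking_train_p": 0.5, "masking_ratio_fixed": True}))
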
@@ -55,16 +55,10 @@ def main():
 	parser.add_argument("--preamble", type=str, default=None)
 	parser.add_argument("--output-filename", type=str, default="index.html")
 
-	parser.add_argument("--language", type=str, default="en")
-
 	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--task", type=str, default="tts")
 	parser.add_argument("--out-path", type=Path, default=None)
 
-	parser.add_argument("--yaml", type=Path, default=None)
-	parser.add_argument("--model", type=Path, default=None)
-	parser.add_argument("--lora", type=Path, default=None)
-
 	parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
 	parser.add_argument("--max-steps", type=int, default=25)
 	parser.add_argument("--max-levels", type=int, default=7)
@@ -362,7 +356,6 @@ def main():
 				text=text,
 				references=[prompt],
 				language=language,
-				input_prompt_length=args.input_prompt_length,
 				seed=seed,
 				tqdm=False,
 				**sampling_kwargs,
@@ -70,6 +70,7 @@ class AR_NAR(Base):
 		cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
 		# rate to train RVQ level AR-ly or NAR-ly
 		masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
+		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else "random"
 		# force set mask training
 		if "len" not in self.capabilities:
 			masking_train_p = 0.0
@@ -108,6 +109,10 @@ class AR_NAR(Base):
 				#p = math.acos(r) / (math.pi * 0.5)
 				#timesteps[i] = 1.0 - clamp(p, 0.0, 1.0)
 				timesteps[i] = random.random()
 
+				# instead make it between [0.2, 0.8]
+				if masking_ratio == "rand":
+					timesteps[i] = (timesteps[i] * 0.6) + 0.2
+
 		# trim resps to only contain all levels below the target level
 		resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
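
To make the new branch concrete, a small sketch (plain floats instead of tensors; the function name is invented) of how the sampled timesteps are squeezed into [0.2, 0.8] only when masking_ratio is "rand", so the derived masking ratio never collapses to a nearly-empty or fully-masked sequence:

import random

def sample_timesteps(batch_size, masking_ratio):
	# uniform per-sample timesteps, as in the existing code path
	timesteps = [random.random() for _ in range(batch_size)]
	# the new branch: affine rescale [0.0, 1.0] -> [0.2, 0.8]
	if masking_ratio == "rand":
		timesteps = [(t * 0.6) + 0.2 for t in timesteps]
	return timesteps

print(sample_timesteps(4, "rand")) # every value lands inside [0.2, 0.8]
print(sample_timesteps(4, 0.8))    # untouched uniform samples in [0.0, 1.0]
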
@@ -436,7 +436,7 @@ class Base(nn.Module):
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 		interleave = self.config.experimental.interleave if self.config is not None else False
 
-		masking_ratio_fixed = self.config.experimental.masking_ratio_fixed if self.config is not None else False
+		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
 
 		layerskip = self.config.experimental.layerskip if self.config is not None else False
@@ -481,7 +481,7 @@ class Base(nn.Module):
 		self.interleave = interleave
 		self.layerskip = layerskip
 		self.inject_timestep_embedding = False # results in bad output
-		self.masking_ratio_fixed = masking_ratio_fixed
+		self.masking_ratio = masking_ratio
 		self.ignore_inputs_for_loss = ignore_inputs_for_loss
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
@@ -537,7 +537,7 @@ class Base(nn.Module):
 
 		# experimental NAR-only mode
 		self.len_emb = Embedding(11, d_model)
-		self.time_emb = TimeEmbedding(d_model) # if not masking_ratio_fixed else None
+		self.time_emb = TimeEmbedding(d_model) # if not masking_ratio else None
 
 		if attention_backend == "auto":
 			attention_backend = "sdpa"
@@ -840,7 +840,6 @@ class Base(nn.Module):
 		state = None,
 
 		layer_skip_lambda = None,
-		timesteps = None,
 
 		output_attentions = False,
 		output_hidden_states = False,
@@ -871,9 +870,6 @@ class Base(nn.Module):
 		if self.layerskip and layer_skip_lambda is not None:
 			kwargs["layer_skip_lambda"] = layer_skip_lambda
 
-		if "len" in self.capabilities and timesteps is not None:
-			kwargs["timesteps"] = timesteps
-
 		output = self.model(**kwargs)
 		x = output["last_hidden_state"]
 
|
@ -1012,13 +1008,18 @@ class Base(nn.Module):
|
||||||
if timestep is not None:
|
if timestep is not None:
|
||||||
# force set to use this classifier level
|
# force set to use this classifier level
|
||||||
classifier_level = "NAR:0:0"
|
classifier_level = "NAR:0:0"
|
||||||
# a paper said to use a fixed masking ratio for training
|
|
||||||
p = 0.8
|
|
||||||
# store timestep information
|
# store timestep information
|
||||||
if not self.masking_ratio_fixed:
|
if self.masking_ratio in ["random", "rand"]:
|
||||||
# cosine scheduled timestep => masking ratio
|
# cosine scheduled timestep => masking ratio
|
||||||
p = math.cos(timestep * math.pi * 0.5)
|
p = math.cos(timestep * math.pi * 0.5)
|
||||||
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
# I don't think is is necessary as the timestep is encoded in the sequence by the number of masked tokens, probably.
|
||||||
|
if self.inject_timestep_embedding:
|
||||||
|
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
|
||||||
|
else:
|
||||||
|
# a paper said to use a fixed masking ratio of 0.8 for training
|
||||||
|
# ...but I want to make it user adjustable
|
||||||
|
p = self.masking_ratio
|
||||||
|
|
||||||
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
|
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
|
||||||
if self.training:
|
if self.training:
|
||||||
dropout_mask = _dropout_mask( resps_list[i], p )
|
dropout_mask = _dropout_mask( resps_list[i], p )
|
||||||
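
A worked example (standalone sketch, not code from the commit) of what the branch above feeds into _dropout_mask: with "random"/"rand" the cosine schedule masks almost everything at small timesteps and almost nothing near timestep 1.0, while a float masking_ratio pins p regardless of the timestep.

import math

def masking_p(masking_ratio, timestep):
	# mirrors the branch in the hunk above, minus the tensor/embedding bookkeeping
	if masking_ratio in ["random", "rand"]:
		# cosine scheduled timestep => masking ratio
		return math.cos(timestep * math.pi * 0.5)
	# fixed, user-adjustable ratio (0.8 reproduces the old hard-coded value)
	return masking_ratio

for t in (0.1, 0.5, 0.9):
	print(f"t={t:.1f}  random: p={masking_p('random', t):.2f}  fixed: p={masking_p(0.8, t):.2f}")
# t=0.1  random: p=0.99  fixed: p=0.80
# t=0.5  random: p=0.71  fixed: p=0.80
# t=0.9  random: p=0.16  fixed: p=0.80
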
|
@ -1597,12 +1598,6 @@ class Base(nn.Module):
|
||||||
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
|
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
|
||||||
classifier_levels = self.get_input( inputs, name="classifier_level" )
|
classifier_levels = self.get_input( inputs, name="classifier_level" )
|
||||||
|
|
||||||
if self.inject_timestep_embedding:
|
|
||||||
timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
|
|
||||||
timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
|
|
||||||
else:
|
|
||||||
timesteps = []
|
|
||||||
|
|
||||||
output = self._forward(
|
output = self._forward(
|
||||||
inputs=x,
|
inputs=x,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
|
@ -1611,7 +1606,6 @@ class Base(nn.Module):
|
||||||
output_attentions = output_attentions,
|
output_attentions = output_attentions,
|
||||||
output_hidden_states = output_hidden_states,
|
output_hidden_states = output_hidden_states,
|
||||||
layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
|
layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None,
|
||||||
timesteps=timesteps,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logits = output.logits
|
logits = output.logits
|
||||||
|
|