This commit is contained in:
mrq 2024-12-20 17:13:37 -06:00
parent d85273609e
commit 91caf00212
3 changed files with 21 additions and 8 deletions

View File

@ -267,6 +267,7 @@ class ModelExperimentalSettings:
ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
noncausal_masks: bool = False # to correct an oversight with Llama always using causal masks...... noncausal_masks: bool = False # to correct an oversight with Llama always using causal masks......
classifiers_bias: bool = True # ugh
# classifier-free guidance training settings # classifier-free guidance training settings
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
@ -477,7 +478,7 @@ class Hyperparameters:
teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
teacher_temperature: float = 1.0 teacher_temperature: float = 1.0
teacher_loss_fn: str = "mse" # kl | mse, use either kl_div or mse_loss (most implementations use kl, some literature says you can use mse) teacher_loss_fn: str = "mse" # kl | mse, use either kl_div or mse_loss (most implementations use kl, some literature says you can use mse)
@dataclass() @dataclass()
class Evaluation: class Evaluation:
batch_size: int = 64 # number of samples per batch during eval / val batch_size: int = 64 # number of samples per batch during eval / val

View File

@ -74,9 +74,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
embedding = torch.nn.Embedding( n_tokens, model_dim ) embedding = torch.nn.Embedding( n_tokens, model_dim )
classifier = torch.nn.Linear( model_dim, n_tokens ) classifier = torch.nn.Linear( model_dim, n_tokens )
#embedding.weight.requires_grad = False # to-do: ignore classifier for RVQ level 7
#classifier.weight.requires_grad = False
#classifier.bias.requires_grad = False
# inject text tokens # inject text tokens
token_start = 0 token_start = 0
@ -192,9 +190,21 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
out_dir.mkdir(parents=True, exist_ok=True) out_dir.mkdir(parents=True, exist_ok=True)
# write weights # write weights
torch_save( model_dict, out_dir / "model.safetensors" ) torch_save( model_dict, out_dir / "model.safetensors" )
# write vocab.json # write tokenizer.json
tokenizer['model']['vocab'] |= tokenizer_vocab tokenizer['model']['vocab'] |= tokenizer_vocab
json_write(tokenizer, out_dir / "tokenizer.json", pretty=True) json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
# write tokenizer_config.json
json_write({
"added_tokens": tokenizer['added_tokens'],
"bos_token": "<bos>",
"eos_token": "</eos>",
"clean_up_tokenization_spaces": True,
"model_input_names": [
"input_ids",
"attention_mask"
],
"tokenizer_class": "PreTrainedTokenizerFast"
}, out_dir / "tokenizer_config.json", pretty=True)
# write config.json # write config.json
json_write({ json_write({
"architectures": [ "architectures": [

View File

@ -251,10 +251,11 @@ class Classifiers(nn.Module):
self, self,
l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token) l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding token_dim: int, # dimensionality of the embedding
l_names: list[str] | None = None, # list of names to map to each classifier l_names: list[str] | None = None, # list of names to map to each classifier,
bias: bool = True,
): ):
super().__init__() super().__init__()
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens]) self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for n_tokens in l_tokens])
self.names = l_names self.names = l_names
def indices( def indices(
@ -446,6 +447,7 @@ class Base(nn.Module):
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
interleave = self.config.experimental.interleave if self.config is not None else False interleave = self.config.experimental.interleave if self.config is not None else False
noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
classifiers_bias = self.config.experimental.classifiers_bias if self.config is not None else False
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@ -781,7 +783,7 @@ class Base(nn.Module):
self.metrics = None self.metrics = None
else: else:
self.classifier = None self.classifier = None
self.classifiers = Classifiers( classifier_l_tokens, d_model, l_names=classifier_l_names ) self.classifiers = Classifiers( classifier_l_tokens, d_model, l_names=classifier_l_names, bias=classifiers_bias )
self.accuracy_metric = None self.accuracy_metric = None
self.precision_metric = None self.precision_metric = None
self.metrics = Metrics( classifier_l_tokens ) self.metrics = Metrics( classifier_l_tokens )