diff --git a/vall_e/config.py b/vall_e/config.py
index b7b8e14..6c11fef 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -267,6 +267,7 @@ class ModelExperimentalSettings:
 
 	ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
 	noncausal_masks: bool = False # to correct an oversight with Llama always using causal masks......
+	classifiers_bias: bool = True # ugh
 
 	# classifier-free guidance training settings
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
@@ -477,7 +478,7 @@ class Hyperparameters:
 	teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
 	teacher_temperature: float = 1.0
 	teacher_loss_fn: str = "mse" # kl | mse, use either kl_div or mse_loss (most implementations use kl, some literature says you can use mse)
-	
+
 @dataclass()
 class Evaluation:
 	batch_size: int = 64 # number of samples per batch during eval / val
diff --git a/vall_e/export.py b/vall_e/export.py
index 9b046f3..16dd39b 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -74,9 +74,7 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	embedding = torch.nn.Embedding( n_tokens, model_dim )
 	classifier = torch.nn.Linear( model_dim, n_tokens )
 
-	#embedding.weight.requires_grad = False
-	#classifier.weight.requires_grad = False
-	#classifier.bias.requires_grad = False
+	# to-do: ignore classifier for RVQ level 7
 
 	# inject text tokens
 	token_start = 0
@@ -192,9 +190,21 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	out_dir.mkdir(parents=True, exist_ok=True)
 	# write weights
 	torch_save( model_dict, out_dir / "model.safetensors" )
-	# write vocab.json
+	# write tokenizer.json
 	tokenizer['model']['vocab'] |= tokenizer_vocab
 	json_write(tokenizer, out_dir / "tokenizer.json", pretty=True)
+	# write tokenizer_config.json
+	json_write({
+		"added_tokens": tokenizer['added_tokens'],
+		"bos_token": "<bos>",
+		"eos_token": "</eos>",
+		"clean_up_tokenization_spaces": True,
+		"model_input_names": [
+			"input_ids",
+			"attention_mask"
+		],
+		"tokenizer_class": "PreTrainedTokenizerFast"
+	}, out_dir / "tokenizer_config.json", pretty=True)
 	# write config.json
 	json_write({
 		"architectures": [
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 9776219..9747507 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -251,10 +251,11 @@ class Classifiers(nn.Module):
 		self,
 		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
-		l_names: list[str] | None = None, # list of names to map to each classifier
+		l_names: list[str] | None = None, # list of names to map to each classifier,
+		bias: bool = True,
 	):
 		super().__init__()
-		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
+		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for n_tokens in l_tokens])
 		self.names = l_names
 
 	def indices(
@@ -446,6 +447,7 @@ class Base(nn.Module):
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 		interleave = self.config.experimental.interleave if self.config is not None else False
 		noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
+		classifiers_bias = self.config.experimental.classifiers_bias if self.config is not None else False
 		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
 
@@ -781,7 +783,7 @@ class Base(nn.Module):
 			self.metrics = None
 		else:
 			self.classifier = None
-			self.classifiers = Classifiers( classifier_l_tokens, d_model, l_names=classifier_l_names )
+			self.classifiers = Classifiers( classifier_l_tokens, d_model, l_names=classifier_l_names, bias=classifiers_bias )
 			self.accuracy_metric = None
 			self.precision_metric = None
 			self.metrics = Metrics( classifier_l_tokens )
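
For reference, a minimal sketch of what the new `classifiers_bias` plumbing amounts to, assuming the constructor shape shown in the base.py hunks above (the token counts and level names below are illustrative, not taken from the diff): the experimental setting read in `Base.__init__` is forwarded into `Classifiers`, which passes it through to every per-level `nn.Linear` head, so `classifiers_bias: False` yields bias-free classifier projections.

import torch.nn as nn

class Classifiers(nn.Module):
	# trimmed-down version of the patched module in vall_e/models/base.py
	def __init__(self, l_tokens, token_dim, l_names=None, bias=True):
		super().__init__()
		# one projection head per level; bias now follows the config flag
		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for n_tokens in l_tokens])
		self.names = l_names

# e.g. two heads over a 1024-dim model, with classifiers_bias: False
heads = Classifiers([1025, 1024], 1024, l_names=["level_0", "level_1"], bias=False)
assert all(head.bias is None for head in heads.proj)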
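
The added tokenizer_config.json is what lets the exported folder load directly through Hugging Face transformers as a PreTrainedTokenizerFast. A quick smoke test along these lines should work, where "./hf_export" is only a placeholder for whatever out_dir convert_to_hf() wrote tokenizer.json and tokenizer_config.json into:

from transformers import AutoTokenizer

# picks up the tokenizer.json + tokenizer_config.json pair written by convert_to_hf()
tokenizer = AutoTokenizer.from_pretrained("./hf_export")
print(tokenizer.bos_token, tokenizer.eos_token)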