From d53038a9e477ea78f8f2a983e4b8879e55900f64 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 19 Jul 2024 15:33:31 -0500
Subject: [PATCH] actually have split classifiers working

---
 vall_e/config.py           | 8 +++++---
 vall_e/engines/__init__.py | 2 +-
 vall_e/models/base.py      | 23 ++++++++++++++++++-----
 vall_e/samplers.py         | 3 +++
 4 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index e1263a2..721b05d 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -200,17 +200,19 @@ class Dataset:
 	def max_duration(self):
 		return self.duration_range[1]

+# collection of experimental variables that should not be tampered with unless you know what you're doing
 @dataclass()
 class ModelExperimentalSettings:
 	hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
 	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
-	split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head
+	split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (to-do: also split for text/prom)
 	audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
 	audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
 	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
-	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
+	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necessary
 	unified_position_ids: bool = True # False will generate position IDs partitioned for each section
+	tie_classifier_to_embedding: bool = False # ties each classifier's weights to its respective embedding; this does not seem to do anything good in testing

 # I really need to clean this up
 @dataclass()
@@ -230,7 +232,7 @@ class Model:
 	dropout: float = 0.1 # adjustable dropout value
 	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 	loss_factors: dict = field(default_factory=lambda: {})
-	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
+	capabilities: list = field(default_factory=lambda: ["ar", "nar"]) # + ["lang", "tone"] if you have your dataset labeled for such
 	experimental: dict | ModelExperimentalSettings | None = None # experimental settings

diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 93eb337..ebb1f38 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -57,7 +57,7 @@ def load_engines(training=True):
 			state = torch.load(load_path, map_location=torch.device(cfg.device))

 			# check if config is defined in state, and re-initialize the model
-			if "config" in state:
+			if "config" in state and False:
 				print("Model config definition in weights, re-loading...")
 				config_state = state["config"]
 				model = get_model( config=cfg.model.__class__( *config_state ), training=training )
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 251c8e4..b42dc1b 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -245,7 +245,13 @@ class AudioClassifier(nn.Module):
 		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])

 	def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
-		return torch.stack( [ self.proj[l]( x ) for x, l in zip(xi, levels) ] )
+		dtype = xi.dtype
+		device = xi.device
+
+		xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
+		# pad if needed
+		xi = [ x if l == 0 else torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ], dim=-1 ) for x, l in zip(xi, levels) ]
+		return torch.stack( xi )

 class Metrics(nn.Module):
 	def __init__(
@@ -273,10 +279,10 @@ class Metrics(nn.Module):
 		) for n_tokens in l_tokens ])

 	def calc_accuracy( self, inputs, targets, quant_levels ):
-		return sum( [ self.accuracy[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+		return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )

 	def calc_precision( self, inputs, targets, quant_levels ):
-		return sum( [ self.precision[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+		return sum( [ self.precision[l]( input[:, :self.precision[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )

 	def __call__(self, *args, **kwargs):
 		return dict(
@@ -421,6 +427,7 @@ class Base(nn.Module):
 		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
 		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
+		tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
 		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True

@@ -757,6 +764,12 @@ class Base(nn.Module):
 			self.precision_metric = None
 			self.metrics = Metrics( l_tokens )

+		"""
+		if tie_classifier_to_embedding:
+			for i, proj in enumerate( self.classifiers.proj ):
+				self.classifiers.proj[i].weight = self.resps_emb.embeddings[i].weight
+		"""
+
 	def _forward(
 		self,
@@ -1063,7 +1076,7 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 	):
 		device = logits[0].device
-		
+
 		# handles tasks where the prompt has task tokens injected in the middle
 		def prompt_input_to_token( input, quant_level ):
 			if isinstance(input, str):
@@ -1262,7 +1275,7 @@ class Base(nn.Module):
 		)

 		if self.classifiers is not None:
-			x = self.classifiers(x, levels = quant_levels) * m
+			x = self.classifiers(x, levels = quant_levels) * m

 		# Remove padding
 		logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 92ea894..bc086db 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -42,6 +42,9 @@ def length_penalize( logits, length, factor=0.0, token=-1 ):

 # Simple way to ban tokens
 def ban_tokens( logits, tokens ):
 	for token in tokens:
+		# skip tokens that aren't present in the logits
+		if token >= logits.shape[-1]:
+			continue
 		logits[:, token] = -float("inf")
 	return logits
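
A note on how the pieces above fit together: AudioClassifier now runs a separate Linear head per RVQ level, and since level 0's head is one class wider than the rest (presumably the AR stop token), the narrower outputs are padded with a column of -inf so everything can still be stacked into one tensor; Metrics then slices each entry back down to that head's num_classes before scoring. Below is a minimal, self-contained sketch of that padding trick. It is not part of the patch, and the dimensions, vocabulary sizes, and variable names are made up for illustration.

import torch
import torch.nn as nn

token_dim = 8                          # hypothetical model width
l_tokens = [1025, 1024, 1024]          # hypothetical vocab sizes; level 0 is one token wider
proj = nn.ModuleList([nn.Linear(token_dim, n) for n in l_tokens])

x = torch.randn(2, 4, token_dim)       # (batch, seq, dim), hypothetical shapes
levels = [0, 2]                        # RVQ level assigned to each batch entry

out = []
for xi, l in zip(x, levels):
	logits = proj[l](xi)               # (seq, l_tokens[l])
	if logits.shape[-1] < max(l_tokens):
		# pad the narrower heads with -inf so every entry has the same width;
		# softmax assigns the padded columns zero probability
		pad = logits.new_full((logits.shape[0], max(l_tokens) - logits.shape[-1]), -float("inf"))
		logits = torch.cat([logits, pad], dim=-1)
	out.append(logits)

stacked = torch.stack(out)             # (batch, seq, max(l_tokens)) -> torch.Size([2, 4, 1025])

This sketch pads generically up to max(l_tokens), whereas the patch special-cases level 0 and appends exactly one -inf column to the other levels; the two amount to the same thing when level 0 is exactly one token wider.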