actually have split classifiers working
This commit is contained in:
parent
692d09f9c1
commit
d53038a9e4
|
@ -200,17 +200,19 @@ class Dataset:
|
|||
def max_duration(self):
|
||||
return self.duration_range[1]
|
||||
|
||||
# collection of experimental variables that should not be tampered with unless you know what you're doing
|
||||
@dataclass()
|
||||
class ModelExperimentalSettings:
|
||||
hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
|
||||
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
|
||||
split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head
|
||||
split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (to-do: also split for text/prom)
|
||||
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
|
||||
audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
|
||||
kv_heads: int = 0 # MHA or GQA (for supported backends)
|
||||
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
|
||||
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
|
||||
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary
|
||||
unified_position_ids: bool = True # False will generate position IDs partitioned for each section
|
||||
tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
|
||||
|
||||
# I really need to clean this up
|
||||
@dataclass()
|
||||
|
@ -230,7 +232,7 @@ class Model:
|
|||
dropout: float = 0.1 # adjustable dropout value
|
||||
#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
|
||||
loss_factors: dict = field(default_factory=lambda: {})
|
||||
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
|
||||
capabilities: list = field(default_factory=lambda: ["ar", "nar"]) # + ["lang", "tone"] if you have your dataset labeled for such
|
||||
|
||||
experimental: dict | ModelExperimentalSettings | None = None # experimental settings
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ def load_engines(training=True):
|
|||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||
|
||||
# check if config is defined in state, and re-initialize the model
|
||||
if "config" in state:
|
||||
if "config" in state and False:
|
||||
print("Model config definition in weights, re-loading...")
|
||||
config_state = state["config"]
|
||||
model = get_model( config=cfg.model.__class__( *config_state ), training=training )
|
||||
|
|
|
@ -245,7 +245,13 @@ class AudioClassifier(nn.Module):
|
|||
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
|
||||
|
||||
def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
|
||||
return torch.stack( [ self.proj[l]( x ) for x, l in zip(xi, levels) ] )
|
||||
dtype = xi.dtype
|
||||
device = xi.device
|
||||
|
||||
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
|
||||
# pad if needed
|
||||
xi = [ x if l == 0 else torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ], dim=-1 ) for x, l in zip(xi, levels) ]
|
||||
return torch.stack( xi )
|
||||
|
||||
class Metrics(nn.Module):
|
||||
def __init__(
|
||||
|
@ -273,10 +279,10 @@ class Metrics(nn.Module):
|
|||
) for n_tokens in l_tokens ])
|
||||
|
||||
def calc_accuracy( self, inputs, targets, quant_levels ):
|
||||
return sum( [ self.accuracy[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
|
||||
return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
|
||||
|
||||
def calc_precision( self, inputs, targets, quant_levels ):
|
||||
return sum( [ self.precision[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
|
||||
return sum( [ self.precision[l]( input[:, :self.precision[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return dict(
|
||||
|
@ -421,6 +427,7 @@ class Base(nn.Module):
|
|||
|
||||
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
|
||||
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
|
||||
tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
|
||||
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
|
||||
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
|
||||
|
||||
|
@ -757,6 +764,12 @@ class Base(nn.Module):
|
|||
self.precision_metric = None
|
||||
self.metrics = Metrics( l_tokens )
|
||||
|
||||
"""
|
||||
if tie_classifier_to_embedding:
|
||||
for i, proj in enumerate( self.classifiers.proj ):
|
||||
self.classifiers.proj[i].weight = self.resps_emb.embeddings[i].weight
|
||||
"""
|
||||
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
|
|
|
@ -42,6 +42,9 @@ def length_penalize( logits, length, factor=0.0, token=-1 ):
|
|||
# Simple way to ban tokens
|
||||
def ban_tokens( logits, tokens ):
|
||||
for token in tokens:
|
||||
# token not in logits
|
||||
if logits.shape[-1] >= token:
|
||||
continue
|
||||
logits[:, token] = -float("inf")
|
||||
return logits
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user