actually have split classifiers working
This commit is contained in:
parent
692d09f9c1
commit
d53038a9e4
|
@ -200,17 +200,19 @@ class Dataset:
|
||||||
def max_duration(self):
    # Upper bound of the configured duration window: the second entry of `duration_range`.
    return self.duration_range[1]
|
||||||
|
|
||||||
|
# Collection of experimental variables that should not be tampered with unless you know what you're doing.
@dataclass()
class ModelExperimentalSettings:
    hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
    interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
    split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (TODO: also split for text/prom)
    audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
    audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
    kv_heads: int = 0 # MHA or GQA (for supported backends)
    p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
    rvq_level_range: list = field(default_factory=lambda: []) # hack to try and limit the RVQ training range for LoRAs; isn't necessary
    unified_position_ids: bool = True # False will generate position IDs partitioned for each section
    tie_classifier_to_embedding: bool = False # ties each classifier's output to its respective embedding; this did not seem to help in testing
|
||||||
|
|
||||||
# I really need to clean this up
|
# I really need to clean this up
|
||||||
@dataclass()
|
@dataclass()
|
||||||
|
@ -230,7 +232,7 @@ class Model:
|
||||||
dropout: float = 0.1 # adjustable dropout value
|
dropout: float = 0.1 # adjustable dropout value
|
||||||
#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
|
#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
|
||||||
loss_factors: dict = field(default_factory=lambda: {})
|
loss_factors: dict = field(default_factory=lambda: {})
|
||||||
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
|
capabilities: list = field(default_factory=lambda: ["ar", "nar"]) # + ["lang", "tone"] if you have your dataset labeled for such
|
||||||
|
|
||||||
experimental: dict | ModelExperimentalSettings | None = None # experimental settings
|
experimental: dict | ModelExperimentalSettings | None = None # experimental settings
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ def load_engines(training=True):
|
||||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||||
|
|
||||||
# check if config is defined in state, and re-initialize the model
|
# check if config is defined in state, and re-initialize the model
|
||||||
if "config" in state:
|
if "config" in state and False:
|
||||||
print("Model config definition in weights, re-loading...")
|
print("Model config definition in weights, re-loading...")
|
||||||
config_state = state["config"]
|
config_state = state["config"]
|
||||||
model = get_model( config=cfg.model.__class__( *config_state ), training=training )
|
model = get_model( config=cfg.model.__class__( *config_state ), training=training )
|
||||||
|
|
|
@ -245,7 +245,13 @@ class AudioClassifier(nn.Module):
|
||||||
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
|
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
|
||||||
|
|
||||||
def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
    """Project every batch item with the classifier head selected by its level.

    Args:
        xi: batch of hidden states; it is iterated along its first dimension
            and each item is paired with the entry of `levels` at the same index.
        levels: per-item index into `self.proj`.

    Returns:
        The per-item logits stacked along a new leading dimension. Items whose
        level is not 0 receive one extra trailing column filled with -inf so
        that all rows can be stacked.
    """
    stacked = []
    for x, level in zip(xi, levels):
        logits = self.proj[level]( x )
        if level != 0:
            # pad if needed
            # NOTE(review): assumes the level-0 head is exactly one class wider than
            # every other head -- confirm against the `l_tokens` passed at construction.
            # The filler is allocated from `logits` itself so its dtype/device always
            # match the projection output (which may differ from `xi` under autocast),
            # and no Python list or host-to-device copy is needed.
            filler = logits.new_full( (*logits.shape[:-1], 1), -float("inf") )
            logits = torch.cat( [ logits, filler ], dim=-1 )
        stacked.append( logits )
    # NOTE(review): a caller that multiplies this result by a mask containing zeros
    # would turn the -inf column into NaN (-inf * 0) -- verify at the call site.
    return torch.stack( stacked )
|
||||||
|
|
||||||
class Metrics(nn.Module):
|
class Metrics(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -273,10 +279,10 @@ class Metrics(nn.Module):
|
||||||
) for n_tokens in l_tokens ])
|
) for n_tokens in l_tokens ])
|
||||||
|
|
||||||
def calc_accuracy( self, inputs, targets, quant_levels ):
    """Average the per-item accuracy over the batch.

    Each item is scored by the metric stored at `self.accuracy[level]`; its
    logits are first cut down to that metric's `num_classes` columns so any
    extra trailing columns are ignored.
    """
    total = 0
    for logits, target, level in zip( inputs, targets, quant_levels ):
        metric = self.accuracy[level]
        total = total + metric( logits[:, :metric.num_classes], target )
    return total / len( inputs )
|
||||||
|
|
||||||
def calc_precision( self, inputs, targets, quant_levels ):
    """Average the per-item precision over the batch.

    Each item is scored by the metric stored at `self.precision[level]`; its
    logits are first cut down to that metric's `num_classes` columns so any
    extra trailing columns are ignored.
    """
    total = 0
    for logits, target, level in zip( inputs, targets, quant_levels ):
        metric = self.precision[level]
        total = total + metric( logits[:, :metric.num_classes], target )
    return total / len( inputs )
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return dict(
|
return dict(
|
||||||
|
@ -421,6 +427,7 @@ class Base(nn.Module):
|
||||||
|
|
||||||
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
|
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
|
||||||
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
|
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
|
||||||
|
tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
|
||||||
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
|
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
|
||||||
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
|
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
|
||||||
|
|
||||||
|
@ -757,6 +764,12 @@ class Base(nn.Module):
|
||||||
self.precision_metric = None
|
self.precision_metric = None
|
||||||
self.metrics = Metrics( l_tokens )
|
self.metrics = Metrics( l_tokens )
|
||||||
|
|
||||||
|
"""
|
||||||
|
if tie_classifier_to_embedding:
|
||||||
|
for i, proj in enumerate( self.classifiers.proj ):
|
||||||
|
self.classifiers.proj[i].weight = self.resps_emb.embeddings[i].weight
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
self,
|
self,
|
||||||
|
@ -1262,7 +1275,7 @@ class Base(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.classifiers is not None:
|
if self.classifiers is not None:
|
||||||
x = self.classifiers(x, levels = quant_levels) * m
|
x = self.classifiers(x, levels = quant_levels) * m
|
||||||
|
|
||||||
# Remove padding
|
# Remove padding
|
||||||
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
|
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
|
||||||
|
|
|
@ -42,6 +42,9 @@ def length_penalize( logits, length, factor=0.0, token=-1 ):
|
||||||
# Simple way to ban tokens
def ban_tokens( logits, tokens ):
    """Set the logit of every banned token to -inf.

    Args:
        logits: 2-D tensor indexed as `logits[:, token]`; it is modified in place.
        tokens: iterable of token ids to ban. Ids that do not address a column
            of `logits` are skipped rather than raising.

    Returns:
        The same `logits` object, after modification.
    """
    vocab_size = logits.shape[-1]
    for token in tokens:
        # token not in logits: the previous guard (`logits.shape[-1] >= token`) was
        # inverted, skipping every valid id and letting out-of-range ids through
        if not -vocab_size <= token < vocab_size:
            continue
        logits[:, token] = -float("inf")
    return logits
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user