actually have split classifiers working

This commit is contained in:
mrq 2024-07-19 15:33:31 -05:00
parent 692d09f9c1
commit d53038a9e4
4 changed files with 27 additions and 9 deletions

View File

@ -200,17 +200,19 @@ class Dataset:
def max_duration(self): def max_duration(self):
return self.duration_range[1] return self.duration_range[1]
# collection of experimental variables that should not be tampered with unless you know what you're doing
@dataclass() @dataclass()
class ModelExperimentalSettings: class ModelExperimentalSettings:
hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal) interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (to-do: also split for text/prom)
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
kv_heads: int = 0 # MHA or GQA (for supported backends) kv_heads: int = 0 # MHA or GQA (for supported backends)
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary
unified_position_ids: bool = True # False will generate position IDs partitioned for each section unified_position_ids: bool = True # False will generate position IDs partitioned for each section
tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
# I really need to clean this up # I really need to clean this up
@dataclass() @dataclass()
@ -230,7 +232,7 @@ class Model:
dropout: float = 0.1 # adjustable dropout value dropout: float = 0.1 # adjustable dropout value
#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good #loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
loss_factors: dict = field(default_factory=lambda: {}) loss_factors: dict = field(default_factory=lambda: {})
capabilities: list = field(default_factory=lambda: ["ar", "nar"]) capabilities: list = field(default_factory=lambda: ["ar", "nar"]) # + ["lang", "tone"] if you have your dataset labeled for such
experimental: dict | ModelExperimentalSettings | None = None # experimental settings experimental: dict | ModelExperimentalSettings | None = None # experimental settings

View File

@ -57,7 +57,7 @@ def load_engines(training=True):
state = torch.load(load_path, map_location=torch.device(cfg.device)) state = torch.load(load_path, map_location=torch.device(cfg.device))
# check if config is defined in state, and re-initialize the model # check if config is defined in state, and re-initialize the model
if "config" in state: if "config" in state and False:
print("Model config definition in weights, re-loading...") print("Model config definition in weights, re-loading...")
config_state = state["config"] config_state = state["config"]
model = get_model( config=cfg.model.__class__( *config_state ), training=training ) model = get_model( config=cfg.model.__class__( *config_state ), training=training )

View File

@ -245,7 +245,13 @@ class AudioClassifier(nn.Module):
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens]) self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
def forward(self, xi: Tensor, levels: list[int] ) -> Tensor: def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
return torch.stack( [ self.proj[l]( x ) for x, l in zip(xi, levels) ] ) dtype = xi.dtype
device = xi.device
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
# pad if needed
xi = [ x if l == 0 else torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ], dim=-1 ) for x, l in zip(xi, levels) ]
return torch.stack( xi )
class Metrics(nn.Module): class Metrics(nn.Module):
def __init__( def __init__(
@ -273,10 +279,10 @@ class Metrics(nn.Module):
) for n_tokens in l_tokens ]) ) for n_tokens in l_tokens ])
def calc_accuracy( self, inputs, targets, quant_levels ): def calc_accuracy( self, inputs, targets, quant_levels ):
return sum( [ self.accuracy[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs ) return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
def calc_precision( self, inputs, targets, quant_levels ): def calc_precision( self, inputs, targets, quant_levels ):
return sum( [ self.precision[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs ) return sum( [ self.precision[l]( input[:, :self.precision[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return dict( return dict(
@ -421,6 +427,7 @@ class Base(nn.Module):
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else "" audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
@ -757,6 +764,12 @@ class Base(nn.Module):
self.precision_metric = None self.precision_metric = None
self.metrics = Metrics( l_tokens ) self.metrics = Metrics( l_tokens )
"""
if tie_classifier_to_embedding:
for i, proj in enumerate( self.classifiers.proj ):
self.classifiers.proj[i].weight = self.resps_emb.embeddings[i].weight
"""
def _forward( def _forward(
self, self,
@ -1262,7 +1275,7 @@ class Base(nn.Module):
) )
if self.classifiers is not None: if self.classifiers is not None:
x = self.classifiers(x, levels = quant_levels) * m x = self.classifiers(x, levels = quant_levels) * m
# Remove padding # Remove padding
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ] logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]

View File

@ -42,6 +42,9 @@ def length_penalize( logits, length, factor=0.0, token=-1 ):
# Simple way to ban tokens # Simple way to ban tokens
def ban_tokens( logits, tokens ): def ban_tokens( logits, tokens ):
for token in tokens: for token in tokens:
# token not in logits
if logits.shape[-1] >= token:
continue
logits[:, token] = -float("inf") logits[:, token] = -float("inf")
return logits return logits