option to split classifier per-level instead of sharing one (at this point I'm just scrambling to try and cope with training a DAC model, the NAR is being a pain)

This commit is contained in:
mrq 2024-06-11 22:28:59 -05:00
parent a7a6e0ac76
commit 65a8960305
6 changed files with 146 additions and 27 deletions
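In short: previously a single nn.Linear(d_model, n_resp_tokens) head produced logits for every RVQ level; with split_classifiers enabled, each level gets its own head. A minimal sketch of the difference, with assumed sizes (not the commit's exact code):

    import torch
    import torch.nn as nn

    d_model, n_tokens, n_levels = 1024, 1024, 8   # assumed sizes

    # shared head: one projection reused for every quantizer level
    shared = nn.Linear(d_model, n_tokens)

    # split heads: a separate projection per RVQ level, picked per sample
    split = nn.ModuleList([nn.Linear(d_model, n_tokens) for _ in range(n_levels)])

    x = torch.randn(2, 50, d_model)   # (batch, seq, dim) hidden states
    quant_levels = [0, 2]             # which RVQ level each sample targets
    logits = torch.stack([split[l](xi) for xi, l in zip(x, quant_levels)])
    print(logits.shape)               # torch.Size([2, 50, 1024])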

View File

@@ -206,6 +206,7 @@ class Model:
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
attention: str = "auto"
audio_embedding_sums: bool = True
split_classifiers: bool = False
dropout: float = 0.1 # adjustable dropout value
#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
loss_factors: dict = field(default_factory=lambda: {})

View File

@@ -615,6 +615,8 @@ class Dataset(_Dataset):
prom_length = 0
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
print(trim_length / cfg.dataset.frames_per_second)
for _ in range(cfg.dataset.max_prompts):
path = random.choice(choices)
if cfg.dataset.use_hdf5:
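For scale: trim_length is a frame count, so the debug print recovers the sampled duration in seconds. A quick check of the arithmetic with assumed values (the actual range and frame rate come from the config):

    import random

    frames_per_second = 75               # assumed; e.g. EnCodec at 24kHz (DAC differs)
    prompt_duration_range = (3.0, 6.0)   # assumed config values, in seconds

    trim_length = int(random.uniform(*prompt_duration_range) * frames_per_second)
    print(trim_length / frames_per_second)   # ~3.0..6.0, the sampled duration in seconds
    # 3s..6s of prompt maps to 225..450 frames at 75 fps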

View File

@@ -168,7 +168,7 @@ class AR_NAR(Base):
quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
else:
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
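The `- 1` matters because random.randint is inclusive on both endpoints, so the old code could draw a level one past the last valid RVQ bin. A quick illustration with assumed bounds:

    import random

    quant_level_range = [0, 8]   # assumed: levels 0..7, where 0 is the AR level

    # randint includes both endpoints; shifting the upper bound down by one
    # keeps every draw inside the valid levels 0..7
    draws = [random.randint(quant_level_range[0], quant_level_range[1] - 1) for _ in range(1000)]
    assert 0 <= min(draws) and max(draws) <= 7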
@@ -496,7 +496,7 @@ def example_usage():
}, f"./data/{cfg.model.arch_type}.pth" )
"""
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()
def sample( name, steps=1000 ):

View File

@@ -53,4 +53,10 @@ try:
AVAILABLE_ARCHES.append("mamba")
AVAILABLE_ARCHES.append("mamba2")
except Exception as e:
print("Error importing `mamba` arch:", e)
print("Error importing `mamba` arch:", e)
try:
from .mmfreelm import *
AVAILABLE_ARCHES.append("mmfreelm")
except Exception as e:
print("Error importing `mmfreelm` arch:", e)

View File

@@ -0,0 +1,6 @@
# https://github.com/ridgerchu/matmulfreellm
import torch
import torch.nn.functional as F
from mmfreelm.models import HGRNBitConfig, HGRNBitModel

View File

@@ -145,7 +145,7 @@ class AudioEmbedding_Old(nn.Module):
class AudioEmbedding(nn.Module):
def __init__(
self,
-l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
+l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
):
@@ -158,7 +158,6 @@ class AudioEmbedding(nn.Module):
#
self.sums = sums
# maintaining compat is hard
def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
@@ -170,6 +169,55 @@ class AudioEmbedding(nn.Module):
return x
# per-level classification
class AudioClassifier(nn.Module):
def __init__(
self,
l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding
):
super().__init__()
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
return torch.stack( [ self.proj[l]( x ) for x, l in zip(xi, levels) ] )
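A usage sketch for AudioClassifier with assumed sizes, assuming the class above is in scope; note torch.stack requires every selected head to emit the same vocab size, so this sketch keeps all levels equal (in the commit, level 0 carries one extra stop-token logit):

    import torch

    clf = AudioClassifier(l_tokens=[1024] * 4, token_dim=512)   # assumed: 4 levels
    x = torch.randn(3, 50, 512)          # padded (batch, seq, dim) hidden states
    logits = clf(x, levels=[0, 1, 3])    # one head chosen per sample
    print(logits.shape)                  # torch.Size([3, 50, 1024])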
class Metrics(nn.Module):
def __init__(
self,
l_tokens: list[int], # per-level token counts (the class iterates over this)
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
):
super().__init__()
self.accuracy = nn.ModuleList([ MulticlassAccuracy(
n_tokens,
top_k=top_k,
average=average,
multidim_average=multidim_average,
ignore_index=ignore_index,
) for n_tokens in l_tokens ])
self.precision = nn.ModuleList([ MulticlassPrecision(
n_tokens,
top_k=top_k,
average=average,
multidim_average=multidim_average,
ignore_index=ignore_index,
) for n_tokens in l_tokens ])
def calc_accuracy( self, inputs, targets, quant_levels ):
return sum( [ self.accuracy[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
def calc_precision( self, inputs, targets, quant_levels ):
return sum( [ self.precision[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
def __call__(self, *args, **kwargs):
return dict(
acc=self.calc_accuracy(*args, **kwargs),
)
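Metrics mirrors the old single-head metric calls but routes each sample through the metric instance matching its quantizer level, then averages over the batch. A sketch, assuming the class above is in scope and assumed sizes:

    import torch

    metrics = Metrics([1024] * 4)                          # assumed: 4 levels, 1024 tokens each
    inputs  = [torch.randn(50, 1024) for _ in range(3)]    # per-sample (seq, vocab) logits
    targets = [torch.randint(0, 1024, (50,)) for _ in range(3)]
    stats = metrics(inputs, targets, [0, 1, 3])            # -> {"acc": batch-averaged top-10 accuracy}
    print(stats["acc"])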
class Base(nn.Module):
# to-do: clean up this property mess
@@ -281,6 +329,9 @@ class Base(nn.Module):
n_prom_tokens = n_audio_tokens
n_resp_tokens = n_audio_tokens + self.causal_size
audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
split_classifiers = self.config.split_classifiers if self.config is not None else True
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
self.tones_emb = None
@@ -306,11 +357,11 @@ class Base(nn.Module):
else:
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
-sums=self.config.audio_embedding_sums if self.config is not None else True
+sums=audio_embedding_sums,
)
self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
-sums=self.config.audio_embedding_sums if self.config is not None else True
+sums=audio_embedding_sums,
)
# useless since I actually removed using these with the input processing overhaul...
@@ -533,29 +584,64 @@ class Base(nn.Module):
#initializer_cfg=initializer_cfg,
)
self.model.gradient_checkpointing = self.gradient_checkpointing
elif self.arch_type == "mmfreelm":
self.model = HGRNBitModel(HGRNBitConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60 * 5, # 75 fps * 60 s * 5 = up to 5 minutes of context
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_heads=n_heads,
#hidden_act="gelu",
#is_encoder_decoder=False,
#is_decoder=True,
attn_mode=hf_attention,
#gradient_checkpointing=self.gradient_checkpointing,
))
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
#if training:
# self.model.training = True
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
-self.classifier = nn.Linear(d_model, n_resp_tokens)
+if not split_classifiers:
+    self.classifier = nn.Linear(d_model, n_resp_tokens)
+    self.classifiers = None
-self.accuracy_metric = MulticlassAccuracy(
-    n_resp_tokens,
-    top_k=10,
-    average="micro",
-    multidim_average="global",
-    ignore_index=self.ignore_index,
-)
+    self.accuracy_metric = MulticlassAccuracy(
+        n_resp_tokens,
+        top_k=10,
+        average="micro",
+        multidim_average="global",
+        ignore_index=self.ignore_index,
+    )
+    self.precision_metric = MulticlassPrecision(
+        n_resp_tokens,
+        top_k=10,
+        average="micro",
+        multidim_average="global",
+        ignore_index=self.ignore_index,
+    )
+    self.metrics = None
+else:
+    levels = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+    self.classifier = None
+    self.classifiers = AudioClassifier( levels, d_model )
+    self.accuracy_metric = None
+    self.precision_metric = None
+    self.metrics = Metrics( levels )
-self.precision_metric = MulticlassPrecision(
-    n_resp_tokens,
-    top_k=10,
-    average="micro",
-    multidim_average="global",
-    ignore_index=self.ignore_index,
-)
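The levels list is what gives the AR level its extra stop-token logit while the NAR levels share a smaller vocabulary; with assumed sizes:

    n_resp_tokens, n_resp_levels = 1025, 8   # assumed: 1024 codes + 1 stop token
    levels = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)
    print(levels)   # [1025, 1024, 1024, 1024, 1024, 1024, 1024, 1024]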
def _forward(
self,
@@ -623,9 +709,17 @@ class Base(nn.Module):
x = self.model( hidden_states=x )
elif self.arch_type == "bitnet":
x = self.model(x)
elif self.arch_type == "mmfreelm":
x = self.model(
attention_mask=m,
inputs_embeds=x,
)
x = x[0]
# output projection layer with masking
-x = self.classifier(x) * mask
+if self.classifier is not None:
+    x = self.classifier(x) * mask
return x, state, aux_loss
@@ -803,7 +897,7 @@ class Base(nn.Module):
# "nll" was in the original implementation and should actually just be called something else
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
-self.stats = dict(
+self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
acc = self.accuracy_metric( inputs, target ),
# precision = self.precision_metric( inputs, target ),
)
@@ -811,7 +905,7 @@ class Base(nn.Module):
self.loss = dict(
nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
)
-self.stats = dict(
+self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
)
@@ -887,7 +981,11 @@ class Base(nn.Module):
# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
else:
self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
if self.metrics is not None:
metrics = self.metrics( batch["logits"], batch["targets"], quant_levels )
self.stats["acc"][name] = metrics["acc"]
else:
self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
def forward(
self,
@@ -896,7 +994,6 @@ class Base(nn.Module):
quant_levels: int | list[int] | Tensor | None = None,
state: dict | list | None = None,
):
x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, m = list_to_tensor(x_list)
@@ -912,6 +1009,10 @@ class Base(nn.Module):
device = x.device
batch_size = len(x_list)
# pure AR
if quant_levels is None:
quant_levels = [ 0 for _ in range(batch_size) ]
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
@@ -934,6 +1035,9 @@ class Base(nn.Module):
state=state,
)
if self.classifiers is not None:
x = self.classifiers(x, levels = quant_levels) * m
# Remove padding
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
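When the split heads are active, projection happens here at the end of forward on the padded batch, with m zeroing logits over padding before each sample is cut back to its true length. A toy version of that final unpadding step (illustrative shapes only):

    import torch

    x = torch.randn(2, 5, 7)           # padded (batch, seq, vocab) logits
    lengths = [3, 5]                   # true per-sample sequence lengths
    logits = [xi[:li] for xi, li in zip(x, lengths)]
    print([tuple(l.shape) for l in logits])   # [(3, 7), (5, 7)]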