option to split the classifier per-level instead of sharing one (at this point I'm just scrambling to cope with training a DAC model; the NAR is being a pain)
parent a7a6e0ac76
commit 65a8960305
@@ -206,6 +206,7 @@ class Model:
     frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
     attention: str = "auto"
     audio_embedding_sums: bool = True
+    split_classifiers: bool = False
     dropout: float = 0.1 # adjustable dropout value
     #loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
     loss_factors: dict = field(default_factory=lambda: {})
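For context, a minimal sketch of toggling the new option when building the config (assuming the `Model` dataclass above can be instantiated with the remaining fields at their defaults):

    # hypothetical: request per-level classifier heads instead of one shared projection
    model_cfg = Model(split_classifiers=True)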
@@ -615,6 +615,8 @@ class Dataset(_Dataset):
         prom_length = 0
         trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
+
+        print(trim_length / cfg.dataset.frames_per_second)
 
         for _ in range(cfg.dataset.max_prompts):
             path = random.choice(choices)
             if cfg.dataset.use_hdf5:
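As a sanity check on the trim-length math, the same formula stands alone like this (the two config values below are made up; only the formula comes from the diff):

    import random

    prompt_duration_range = (3.0, 6.0)  # seconds, hypothetical
    frames_per_second = 75              # hypothetical codec frame rate

    trim_length = int(random.uniform(*prompt_duration_range) * frames_per_second)
    print(trim_length / frames_per_second)  # back to seconds, matching the debug print above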
@@ -168,7 +168,7 @@ class AR_NAR(Base):
             quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
         else:
             # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-            quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+            quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
 
         resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
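The `- 1` matters because `random.randint` is inclusive on both ends, so the old call could presumably sample one level past the last valid RVQ bin. A quick illustration with made-up bounds:

    import random

    quant_level_range = (0, 8)  # hypothetical: 8 RVQ levels, indexed 0..7

    # with an inclusive upper bound, randint(0, 8) occasionally yields 8, out of range;
    # subtracting 1 keeps every sample within the valid levels
    levels = [random.randint(quant_level_range[0], quant_level_range[1] - 1) for _ in range(4)]
    assert all(0 <= l <= 7 for l in levels)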
@@ -496,7 +496,7 @@ def example_usage():
     }, f"./data/{cfg.model.arch_type}.pth" )
     """
 
-    print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+    print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
     @torch.inference_mode()
     def sample( name, steps=1000 ):
@@ -53,4 +53,10 @@ try:
     AVAILABLE_ARCHES.append("mamba")
     AVAILABLE_ARCHES.append("mamba2")
 except Exception as e:
     print("Error importing `mamba` arch:", e)
+
+try:
+    from .mmfreelm import *
+    AVAILABLE_ARCHES.append("mmfreelm")
+except Exception as e:
+    print("Error importing `mmfreelm` arch:", e)
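A hypothetical downstream check against the registry, to show why the guarded imports matter (this consumer is illustrative, not in the commit):

    # fail loudly at model-construction time if the optional dependency didn't import
    if "mmfreelm" not in AVAILABLE_ARCHES:
        raise RuntimeError("mmfreelm arch requested but matmulfreellm is not installed")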
vall_e/models/arch/mmfreelm.py (new file, +6 lines)
@@ -0,0 +1,6 @@
+# https://github.com/ridgerchu/matmulfreellm
+
+import torch
+import torch.nn.functional as F
+
+from mmfreelm.models import HGRNBitConfig, HGRNBitModel
@@ -145,7 +145,7 @@ class AudioEmbedding_Old(nn.Module):
 class AudioEmbedding(nn.Module):
     def __init__(
         self,
-        l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
+        l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
         token_dim: int, # dimensionality of the embedding
         sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
     ):
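The annotation fix reflects that `l_tokens` really is a list with one vocab size per RVQ level, where level 0 carries one extra slot for the AR stop token. With hypothetical sizes, the construction used elsewhere in this commit looks like:

    n_resp_tokens, n_resp_levels = 1025, 8  # hypothetical
    l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)
    # -> [1025, 1024, 1024, 1024, 1024, 1024, 1024, 1024]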
@@ -158,7 +158,6 @@ class AudioEmbedding(nn.Module):
     #
         self.sums = sums
 
-    # maintaining compat is hard
     def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
         quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
@@ -170,6 +169,55 @@ class AudioEmbedding(nn.Module):
 
         return x
 
+# per-level classification
+class AudioClassifier(nn.Module):
+    def __init__(
+        self,
+        l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
+        token_dim: int, # dimensionality of the embedding
+    ):
+        super().__init__()
+        self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
+
+    def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
+        return torch.stack( [ self.proj[l]( x ) for x, l in zip(xi, levels) ] )
+
+class Metrics(nn.Module):
+    def __init__(
+        self,
+        l_tokens: int | list[int],
+        top_k = 10,
+        average="micro",
+        multidim_average="global",
+        ignore_index = -100
+    ):
+        super().__init__()
+        self.accuracy = nn.ModuleList([ MulticlassAccuracy(
+            n_tokens,
+            top_k=top_k,
+            average=average,
+            multidim_average=multidim_average,
+            ignore_index=ignore_index,
+        ) for n_tokens in l_tokens ])
+        self.precision = nn.ModuleList([ MulticlassPrecision(
+            n_tokens,
+            top_k=top_k,
+            average=average,
+            multidim_average=multidim_average,
+            ignore_index=ignore_index,
+        ) for n_tokens in l_tokens ])
+
+    def calc_accuracy( self, inputs, targets, quant_levels ):
+        return sum( [ self.accuracy[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+
+    def calc_precision( self, inputs, targets, quant_levels ):
+        return sum( [ self.precision[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+
+    def __call__(self, *args, **kwargs):
+        return dict(
+            acc=self.calc_accuracy(*args, **kwargs),
+        )
+
 class Base(nn.Module):
     # to-do: clean up this property mess
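A small sketch of what the per-level routing does for a single batch (sizes made up). Note that `torch.stack` requires every selected head to share an output size, so a batch mixing level 0 (which has the extra stop-token class) with other levels would need padding or a plain list instead:

    import torch
    import torch.nn as nn

    token_dim = 1024
    l_tokens = [1025] + [1024] * 7          # level 0 gets the extra stop token

    proj = nn.ModuleList([nn.Linear(token_dim, n) for n in l_tokens])

    x = torch.randn(3, 50, token_dim)       # (batch, seq, dim) hidden states
    levels = [1, 2, 5]                      # per-item target RVQ level (all NAR here)
    logits = torch.stack([proj[l](xi) for xi, l in zip(x, levels)])
    print(logits.shape)                     # torch.Size([3, 50, 1024])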
@@ -281,6 +329,9 @@ class Base(nn.Module):
         n_prom_tokens = n_audio_tokens
         n_resp_tokens = n_audio_tokens + self.causal_size
 
+        audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
+        split_classifiers = self.config.split_classifiers if self.config is not None else True
+
         self.text_emb = Embedding(n_text_tokens, d_model)
         self.langs_emb = None
         self.tones_emb = None
@@ -306,11 +357,11 @@ class Base(nn.Module):
         else:
             self.proms_emb = AudioEmbedding(
                 [n_prom_tokens] * self.n_prom_levels, d_model,
-                sums=self.config.audio_embedding_sums if self.config is not None else True
+                sums=audio_embedding_sums,
             )
             self.resps_emb = AudioEmbedding(
                 [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
-                sums=self.config.audio_embedding_sums if self.config is not None else True
+                sums=audio_embedding_sums,
             )
 
         # useless since I actually removed using these with the input processing overhaul...
@@ -533,29 +584,64 @@ class Base(nn.Module):
                 #initializer_cfg=initializer_cfg,
             )
             self.model.gradient_checkpointing = self.gradient_checkpointing
+        elif self.arch_type == "mmfreelm":
+            self.model = HGRNBitModel(HGRNBitConfig(
+                vocab_size=n_resp_tokens,
+                hidden_size=d_model,
+                max_position_embeddings=75 * 60 * 5, # max-length of 5 minutes at 75 frames per second
+                intermediate_size=d_model*4,
+                num_hidden_layers=n_layers,
+                num_heads=n_heads,
+                #hidden_act="gelu",
+                #is_encoder_decoder=False,
+                #is_decoder=True,
+                attn_mode=hf_attention,
+                #gradient_checkpointing=self.gradient_checkpointing,
+            ))
+
+            if self.gradient_checkpointing and not self.model.gradient_checkpointing:
+                self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+                    use_reentrant=False
+                ))
+
+            #if training:
+            #    self.model.training = True
         else:
             raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
         if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
             self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
 
-        self.classifier = nn.Linear(d_model, n_resp_tokens)
-
-        self.accuracy_metric = MulticlassAccuracy(
-            n_resp_tokens,
-            top_k=10,
-            average="micro",
-            multidim_average="global",
-            ignore_index=self.ignore_index,
-        )
-        self.precision_metric = MulticlassPrecision(
-            n_resp_tokens,
-            top_k=10,
-            average="micro",
-            multidim_average="global",
-            ignore_index=self.ignore_index,
-        )
+        if not split_classifiers:
+            self.classifier = nn.Linear(d_model, n_resp_tokens)
+            self.classifiers = None
+
+            self.accuracy_metric = MulticlassAccuracy(
+                n_resp_tokens,
+                top_k=10,
+                average="micro",
+                multidim_average="global",
+                ignore_index=self.ignore_index,
+            )
+            self.precision_metric = MulticlassPrecision(
+                n_resp_tokens,
+                top_k=10,
+                average="micro",
+                multidim_average="global",
+                ignore_index=self.ignore_index,
+            )
+
+            self.metrics = None
+        else:
+            levels = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+
+            self.classifier = None
+            self.classifiers = AudioClassifier( levels, d_model )
+            self.accuracy_metric = None
+            self.precision_metric = None
+            self.metrics = Metrics( levels )
 
     def _forward(
         self,
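To see how the `Metrics` wrapper dispatches per level, a rough usage sketch (shapes and the import path are assumptions; `MulticlassAccuracy` is torchmetrics-style, matching the constructor arguments in the diff):

    import torch
    from vall_e.models.base import Metrics  # hypothetical import path

    metrics = Metrics([1025] + [1024] * 7)

    # one (seq_len, n_classes) logit tensor and one (seq_len,) target per batch item
    inputs  = [torch.randn(50, 1024), torch.randn(60, 1024)]
    targets = [torch.randint(0, 1024, (50,)), torch.randint(0, 1024, (60,))]

    print(metrics(inputs, targets, quant_levels=[1, 3]))  # {'acc': tensor(...)}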
@@ -623,9 +709,17 @@ class Base(nn.Module):
             x = self.model( hidden_states=x )
         elif self.arch_type == "bitnet":
             x = self.model(x)
+        elif self.arch_type == "mmfreelm":
+            x = self.model(
+                attention_mask=m,
+                inputs_embeds=x,
+            )
+
+            x = x[0]
 
         # output projection layer with masking
-        x = self.classifier(x) * mask
+        if self.classifier is not None:
+            x = self.classifier(x) * mask
 
         return x, state, aux_loss
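Putting the two mmfreelm pieces together, a minimal construction-and-call sketch using only the API surface shown in this commit (sizes are made up, and optional kwargs like `attn_mode` are omitted; treat this as an unverified sketch rather than the library's documented usage):

    import torch
    from mmfreelm.models import HGRNBitConfig, HGRNBitModel

    d_model, n_layers, n_heads = 1024, 12, 16
    model = HGRNBitModel(HGRNBitConfig(
        vocab_size=1025,
        hidden_size=d_model,
        intermediate_size=d_model * 4,
        num_hidden_layers=n_layers,
        num_heads=n_heads,
    ))

    x = torch.randn(1, 50, d_model)           # pre-computed input embeddings
    m = torch.ones(1, 50, dtype=torch.long)   # attention mask
    out = model(attention_mask=m, inputs_embeds=x)
    x = out[0]                                # hidden states, (1, 50, d_model)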
@@ -803,7 +897,7 @@ class Base(nn.Module):
                 # "nll" was in the original implementation and should actually just be called something else
                 nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
             )
-            self.stats = dict(
+            self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
                 acc = self.accuracy_metric( inputs, target ),
                 # precision = self.precision_metric( inputs, target ),
             )
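Both branches lean on `ignore_index` to keep padded positions out of the loss and the metrics; a quick standalone illustration (shapes hypothetical):

    import torch
    import torch.nn.functional as F

    logits  = torch.randn(50, 1024)
    targets = torch.randint(0, 1024, (50,))
    targets[-10:] = -100  # mark padding; these positions contribute nothing to the loss

    loss = F.cross_entropy(logits, targets, ignore_index=-100)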
@@ -811,7 +905,7 @@ class Base(nn.Module):
             self.loss = dict(
                 nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
             )
-            self.stats = dict(
+            self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
                 acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
             )
@@ -887,7 +981,11 @@ class Base(nn.Module):
             # this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
             else:
                 self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
-                self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
+                if self.metrics is not None:
+                    metrics = self.metrics( batch["logits"], batch["targets"], quant_levels )
+                    self.stats["acc"][name] = metrics["acc"]
+                else:
+                    self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
 
     def forward(
         self,
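For reference, the per-name weighting this path implements corresponds to a config like the one commented out in the `Model` dataclass earlier; a hypothetical setting and the weighted sum it implies (`named_losses` is an illustrative name, not from the commit):

    # hypothetical: down-weight text loss, keep prompt/response at full weight
    loss_factors = { "text": 0.1, "prom": 1.0, "resp": 1.0 }

    # each named loss is scaled by its factor before summing (sketch)
    total = sum(loss * loss_factors.get(name, 1.0) for name, loss in named_losses.items())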
@@ -896,7 +994,6 @@ class Base(nn.Module):
         quant_levels: int | list[int] | Tensor | None = None,
         state: dict | list | None = None,
     ):
-
         x_list = self.inputs_to_embeddings( inputs, quant_levels )
         x, m = list_to_tensor(x_list)
@@ -912,6 +1009,10 @@ class Base(nn.Module):
 
         device = x.device
         batch_size = len(x_list)
+
+        # pure AR
+        if quant_levels is None:
+            quant_levels = [ 0 for _ in range(batch_size) ]
 
         # pad our input and mask, but retain the original length by doing it after
         if self.l_padding and x.shape[1] % self.l_padding != 0:
@@ -934,6 +1035,9 @@ class Base(nn.Module):
             state=state,
         )
 
+        if self.classifiers is not None:
+            x = self.classifiers(x, levels = quant_levels) * m
+
         # Remove padding
         logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
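Tying the new pieces together, a toy version of the tail of `forward()` with split classifiers enabled (all sizes invented):

    import torch
    import torch.nn as nn

    d_model, n_tokens = 16, 32
    heads = nn.ModuleList([nn.Linear(d_model, n_tokens) for _ in range(4)])

    seq_lens = [5, 3]
    x = torch.randn(2, 5, d_model)  # backbone output, padded to the longest item
    m = torch.tensor([[1.]*5, [1.]*3 + [0.]*2]).unsqueeze(-1)
    quant_levels = [1, 2]

    # per-item head, then mask out padding, then strip it per item
    x = torch.stack([heads[l](xi) for xi, l in zip(x, quant_levels)]) * m
    logits = [hi[:li] for hi, li in zip(x, seq_lens)]
    print([tuple(t.shape) for t in logits])  # [(5, 32), (3, 32)]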