option to split classifier per-level instead of sharing one (at this point I'm just scrambling to try and cope with training a DAC model, the NAR is being a pain)
commit 65a8960305
parent a7a6e0ac76
@@ -206,6 +206,7 @@ class Model:
 	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
+	split_classifiers: bool = False
 	dropout: float = 0.1 # adjustable dropout value
 	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 	loss_factors: dict = field(default_factory=lambda: {})
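For context, a minimal illustrative sketch (not the repo's actual config loader; the dataclass and helper names here are hypothetical) of what this flag is meant to toggle: with split_classifiers set, the model builds one output head per RVQ level instead of a single shared nn.Linear.

# Illustrative sketch only: trimmed-down stand-ins for the real Model config
# and Base constructor, showing what the split_classifiers flag selects.
from dataclasses import dataclass

import torch.nn as nn

@dataclass
class ToyModelConfig:  # hypothetical name; the real dataclass is vall_e's Model
    split_classifiers: bool = False

def build_heads(cfg: ToyModelConfig, d_model: int, levels: list[int]) -> nn.Module:
    if cfg.split_classifiers:
        # one projection per RVQ level (what AudioClassifier does later in this commit)
        return nn.ModuleList([nn.Linear(d_model, n_tokens) for n_tokens in levels])
    # shared projection sized for the widest vocabulary (the AR level plus its stop token)
    return nn.Linear(d_model, max(levels))

# assumed sizes: 1024 audio tokens plus one stop token on the AR level
heads = build_heads(ToyModelConfig(split_classifiers=True), d_model=1024, levels=[1025] + [1024] * 7)
print(type(heads))  # nn.ModuleList with 8 per-level heads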
@@ -615,6 +615,8 @@ class Dataset(_Dataset):
 		prom_length = 0
 		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
 
+		print(trim_length / cfg.dataset.frames_per_second)
+
 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)
 			if cfg.dataset.use_hdf5:
@@ -168,7 +168,7 @@ class AR_NAR(Base):
 			quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
 		else:
 			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
 
 		resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
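The added "- 1" matters because random.randint is inclusive on both ends; without it, a level index equal to the level count could be sampled and index past the last NAR level. A quick standalone check (the range [0, 8] is an assumed example, not a value read from the repo):

import random

quant_level_range = [0, 8]  # assumed example: 8 RVQ levels, valid indices 0..7

# random.randint(a, b) returns values in [a, b] *inclusive*,
# so the upper bound must be level_count - 1 to stay a valid index.
samples = [random.randint(quant_level_range[0], quant_level_range[1] - 1) for _ in range(10_000)]
assert min(samples) >= 0 and max(samples) <= 7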
@@ -496,7 +496,7 @@ def example_usage():
 	}, f"./data/{cfg.model.arch_type}.pth" )
 	"""
 
-	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+	print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
 	@torch.inference_mode()
 	def sample( name, steps=1000 ):
@@ -54,3 +54,9 @@ try:
 	AVAILABLE_ARCHES.append("mamba2")
 except Exception as e:
 	print("Error importing `mamba` arch:", e)
+
+try:
+	from .mmfreelm import *
+	AVAILABLE_ARCHES.append("mmfreelm")
+except Exception as e:
+	print("Error importing `mmfreelm` arch:", e)
vall_e/models/arch/mmfreelm.py (new file, 6 lines)
@@ -0,0 +1,6 @@
+# https://github.com/ridgerchu/matmulfreellm
+
+import torch
+import torch.nn.functional as F
+
+from mmfreelm.models import HGRNBitConfig, HGRNBitModel
@@ -145,7 +145,7 @@ class AudioEmbedding_Old(nn.Module):
 class AudioEmbedding(nn.Module):
 	def __init__(
 		self,
-		l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
+		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
 		sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
 	):
@@ -158,7 +158,6 @@ class AudioEmbedding(nn.Module):
 		#
 		self.sums = sums
 
-	# maintaining compat is hard
 	def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
 		quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
@@ -170,6 +169,55 @@ class AudioEmbedding(nn.Module):
 
 		return x
 
+# per-level classification
+class AudioClassifier(nn.Module):
+	def __init__(
+		self,
+		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
+		token_dim: int, # dimensionality of the embedding
+	):
+		super().__init__()
+		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
+
+	def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
+		return torch.stack( [ self.proj[l]( x ) for x, l in zip(xi, levels) ] )
+
+class Metrics(nn.Module):
+	def __init__(
+		self,
+		l_tokens: int | list[int],
+		top_k = 10,
+		average="micro",
+		multidim_average="global",
+		ignore_index = -100
+	):
+		super().__init__()
+		self.accuracy = nn.ModuleList([ MulticlassAccuracy(
+			n_tokens,
+			top_k=top_k,
+			average=average,
+			multidim_average=multidim_average,
+			ignore_index=ignore_index,
+		) for n_tokens in l_tokens ])
+		self.precision = nn.ModuleList([ MulticlassPrecision(
+			n_tokens,
+			top_k=top_k,
+			average=average,
+			multidim_average=multidim_average,
+			ignore_index=ignore_index,
+		) for n_tokens in l_tokens ])
+
+	def calc_accuracy( self, inputs, targets, quant_levels ):
+		return sum( [ self.accuracy[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+
+	def calc_precision( self, inputs, targets, quant_levels ):
+		return sum( [ self.precision[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+
+	def __call__(self, *args, **kwargs):
+		return dict(
+			acc=self.calc_accuracy(*args, **kwargs),
+		)
+
 class Base(nn.Module):
 	# to-do: clean up this property mess
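A minimal standalone sketch (assumed sizes, not code from the repo) of how these per-level heads behave: each batch item is routed through the head for its sampled quantizer level, and the per-item logits are stacked back into one tensor.

import torch
import torch.nn as nn

# Small assumed sizes, purely for illustration.
levels = [17, 16, 16, 16]   # per-level vocab sizes (the AR level carries one extra stop token)
d_model = 8
proj = nn.ModuleList([nn.Linear(d_model, n_tokens) for n_tokens in levels])

batch, seq = 3, 5
x = torch.randn(batch, seq, d_model)   # padded hidden states from the backbone
quant_levels = [1, 3, 2]               # one sampled RVQ level per batch item

# Route each item through its level's head. torch.stack needs every selected
# head to share an output width, which holds here because all NAR levels do.
logits = torch.stack([proj[l](xb) for xb, l in zip(x, quant_levels)])
print(logits.shape)  # torch.Size([3, 5, 16])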
@@ -281,6 +329,9 @@ class Base(nn.Module):
 		n_prom_tokens = n_audio_tokens
 		n_resp_tokens = n_audio_tokens + self.causal_size
 
+		audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
+		split_classifiers = self.config.split_classifiers if self.config is not None else True
+
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
 		self.tones_emb = None
@@ -306,11 +357,11 @@ class Base(nn.Module):
 		else:
 			self.proms_emb = AudioEmbedding(
 				[n_prom_tokens] * self.n_prom_levels, d_model,
-				sums=self.config.audio_embedding_sums if self.config is not None else True
+				sums=audio_embedding_sums,
 			)
 			self.resps_emb = AudioEmbedding(
 				[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
-				sums=self.config.audio_embedding_sums if self.config is not None else True
+				sums=audio_embedding_sums,
 			)
 
 		# useless since I actually removed using these with the input processing overhaul...
@@ -533,13 +584,37 @@ class Base(nn.Module):
 				#initializer_cfg=initializer_cfg,
 			)
 			self.model.gradient_checkpointing = self.gradient_checkpointing
+		elif self.arch_type == "mmfreelm":
+			self.model = HGRNBitModel(HGRNBitConfig(
+				vocab_size=n_resp_tokens,
+				hidden_size=d_model,
+				max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
+				intermediate_size=d_model*4,
+				num_hidden_layers=n_layers,
+				num_heads=n_heads,
+				#hidden_act="gelu",
+				#is_encoder_decoder=False,
+				#is_decoder=True,
+				attn_mode=hf_attention,
+				#gradient_checkpointing=self.gradient_checkpointing,
+			))
+
+			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+					use_reentrant=False
+				))
+
+			#if training:
+			#	self.model.training = True
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
 		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
 			self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
 
+		if not split_classifiers:
 			self.classifier = nn.Linear(d_model, n_resp_tokens)
+			self.classifiers = None
 
 			self.accuracy_metric = MulticlassAccuracy(
 				n_resp_tokens,
@@ -557,6 +632,17 @@ class Base(nn.Module):
 				ignore_index=self.ignore_index,
 			)
 
+			self.metrics = None
+		else:
+			levels = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+
+			self.classifier = None
+			self.classifiers = AudioClassifier( levels, d_model )
+			self.accuracy_metric = None
+			self.precision_metric = None
+			self.metrics = Metrics( levels )
+
+
 	def _forward(
 		self,
 		inputs,
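As a worked example of the levels list (assumed sizes: 1024-entry codebooks and causal_size = 1, not values read from any config here): only the AR level keeps the extra stop token, so its head is one token wider than the NAR heads.

# Assumed example values:
n_audio_tokens = 1024
causal_size = 1                      # stop token reserved for the AR level
n_resp_levels = 8

n_resp_tokens = n_audio_tokens + causal_size
levels = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)
print(levels)  # [1025, 1024, 1024, 1024, 1024, 1024, 1024, 1024]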
@@ -623,8 +709,16 @@ class Base(nn.Module):
 			x = self.model( hidden_states=x )
 		elif self.arch_type == "bitnet":
 			x = self.model(x)
+		elif self.arch_type == "mmfreelm":
+			x = self.model(
+				attention_mask=m,
+				inputs_embeds=x,
+			)
+
+			x = x[0]
 
 		# output projection layer with masking
+		if self.classifier is not None:
 			x = self.classifier(x) * mask
 
 		return x, state, aux_loss
@@ -803,7 +897,7 @@ class Base(nn.Module):
 					# "nll" was in the original implementation and should actually just be called something else
 					nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
 				)
-				self.stats = dict(
+				self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
 					acc = self.accuracy_metric( inputs, target ),
 					# precision = self.precision_metric( inputs, target ),
 				)
@@ -811,7 +905,7 @@ class Base(nn.Module):
 				self.loss = dict(
 					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
 				)
-				self.stats = dict(
+				self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
 					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 				)
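A minimal standalone sketch (assumed shapes and sizes; torchmetrics required) of the per-level metrics path these two changes switch to: each sequence's logits and targets are scored by the accuracy module matching its quantizer level, then the per-item scores are averaged.

import torch
import torch.nn as nn
from torchmetrics.classification import MulticlassAccuracy

# Assumed toy sizes: 3 RVQ levels, 16-token vocab per level, top-5 accuracy.
levels = [16, 16, 16]
accuracy = nn.ModuleList([
    MulticlassAccuracy(n_tokens, top_k=5, average="micro",
                       multidim_average="global", ignore_index=-100)
    for n_tokens in levels
])

# One (logits, targets) pair per batch item, plus its sampled level.
logits  = [torch.randn(12, 16), torch.randn(9, 16)]
targets = [torch.randint(0, 16, (12,)), torch.randint(0, 16, (9,))]
quant_levels = [0, 2]

acc = sum(accuracy[l](inp, tgt) for inp, tgt, l in zip(logits, targets, quant_levels)) / len(logits)
print(float(acc))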
@@ -887,6 +981,10 @@ class Base(nn.Module):
 				# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
 				else:
 					self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
+					if self.metrics is not None:
+						metrics = self.metrics( batch["logits"], batch["targets"], quant_levels )
+						self.stats["acc"][name] = metrics["acc"]
+					else:
 						self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
 
 	def forward(
@@ -896,7 +994,6 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 		state: dict | list | None = None,
 	):
 
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 		x, m = list_to_tensor(x_list)
@@ -913,6 +1010,10 @@ class Base(nn.Module):
 		device = x.device
 		batch_size = len(x_list)
 
+		# pure AR
+		if quant_levels is None:
+			quant_levels = [ 0 for _ in range(batch_size) ]
+
 		# pad our input and mask, but retain the original length by doing it after
 		if self.l_padding and x.shape[1] % self.l_padding != 0:
 			# pad input
|
||||||
state=state,
|
state=state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.classifiers is not None:
|
||||||
|
x = self.classifiers(x, levels = quant_levels) * m
|
||||||
|
|
||||||
# Remove padding
|
# Remove padding
|
||||||
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
|
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
|
||||||
|
|
||||||
|
|
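A small sketch (assumed shapes, standalone) of the mask-and-trim step that follows the per-level projection: padded positions are zeroed by the mask m, then each item's logits are sliced back to its original length.

import torch

# Assumed: 2 sequences with true lengths 4 and 2, padded to 4, vocab of 16.
lengths = [4, 2]
x = torch.randn(2, 4, 16)                                    # per-level logits from the classifier heads
m = torch.tensor([[1.]*4, [1., 1., 0., 0.]]).unsqueeze(-1)   # padding mask, shape [batch, seq, 1]

x = x * m                                                    # zero out padded positions
logits = [xi[:li] for xi, li in zip(x, lengths)]
print([t.shape for t in logits])  # [torch.Size([4, 16]), torch.Size([2, 16])]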