From 65a8960305bb57fcccc74308678b31d9565dd2d2 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 11 Jun 2024 22:28:59 -0500
Subject: [PATCH] option to split classifier per-level instead of sharing one
 (at this point I'm just scrambling to try and cope with training a DAC
 model, the NAR is being a pain)

---
 vall_e/config.py               |   1 +
 vall_e/data.py                 |   2 +
 vall_e/models/ar_nar.py        |   4 +-
 vall_e/models/arch/__init__.py |   8 +-
 vall_e/models/arch/mmfreelm.py |   6 ++
 vall_e/models/base.py          | 152 +++++++++++++++++++++++++++------
 6 files changed, 146 insertions(+), 27 deletions(-)
 create mode 100644 vall_e/models/arch/mmfreelm.py

diff --git a/vall_e/config.py b/vall_e/config.py
index f50a879..ee934b5 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -206,6 +206,7 @@ class Model:
 	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
+	split_classifiers: bool = False
 	dropout: float = 0.1 # adjustable dropout value
 	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 	loss_factors: dict = field(default_factory=lambda: {})
diff --git a/vall_e/data.py b/vall_e/data.py
index cd16a70..5e25857 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -615,6 +615,8 @@ class Dataset(_Dataset):
 		prom_length = 0
 		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
 
+		print(trim_length / cfg.dataset.frames_per_second)
+
 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)
 			if cfg.dataset.use_hdf5:
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 03258a2..23bfd80 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -168,7 +168,7 @@ class AR_NAR(Base):
 				quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
 			else:
 				# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-				quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+				quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
 
 			resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 
@@ -496,7 +496,7 @@ def example_usage():
 		}, f"./data/{cfg.model.arch_type}.pth" )
 	"""
 
-	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+	print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
 	@torch.inference_mode()
 	def sample( name, steps=1000 ):
diff --git a/vall_e/models/arch/__init__.py b/vall_e/models/arch/__init__.py
index 903a7c6..ee6dc41 100755
--- a/vall_e/models/arch/__init__.py
+++ b/vall_e/models/arch/__init__.py
@@ -53,4 +53,10 @@ try:
 	AVAILABLE_ARCHES.append("mamba")
 	AVAILABLE_ARCHES.append("mamba2")
 except Exception as e:
-	print("Error importing `mamba` arch:", e)
\ No newline at end of file
+	print("Error importing `mamba` arch:", e)
+
+try:
+	from .mmfreelm import *
+	AVAILABLE_ARCHES.append("mmfreelm")
+except Exception as e:
+	print("Error importing `mmfreelm` arch:", e)
\ No newline at end of file
diff --git a/vall_e/models/arch/mmfreelm.py b/vall_e/models/arch/mmfreelm.py
new file mode 100644
index 0000000..86d2fbe
--- /dev/null
+++ b/vall_e/models/arch/mmfreelm.py
@@ -0,0 +1,6 @@
+# https://github.com/ridgerchu/matmulfreellm
+
+import torch
+import torch.nn.functional as F
+
+from mmfreelm.models import HGRNBitConfig, HGRNBitModel
\ No newline at end of file
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c05bb9b..c30ad0e 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -145,7 +145,7 @@ class AudioEmbedding_Old(nn.Module):
 class AudioEmbedding(nn.Module):
 	def __init__(
 		self,
-		l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
+		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
 		sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
 	):
@@ -158,7 +158,6 @@ class AudioEmbedding(nn.Module):
 		#
 		self.sums = sums
 
-	# maintaining compat is hard
 	def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
 		quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
 
@@ -170,6 +169,55 @@ class AudioEmbedding(nn.Module):
 
 		return x
 
+# per-level classification
+class AudioClassifier(nn.Module):
+	def __init__(
+		self,
+		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
+		token_dim: int, # dimensionality of the embedding
+	):
+		super().__init__()
+		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
+
+	def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
+		return torch.stack( [ self.proj[l]( x ) for x, l in zip(xi, levels) ] )
+
+class Metrics(nn.Module):
+	def __init__(
+		self,
+		l_tokens: int | list[int],
+		top_k = 10,
+		average="micro",
+		multidim_average="global",
+		ignore_index = -100
+	):
+		super().__init__()
+		self.accuracy = nn.ModuleList([ MulticlassAccuracy(
+			n_tokens,
+			top_k=top_k,
+			average=average,
+			multidim_average=multidim_average,
+			ignore_index=ignore_index,
+		) for n_tokens in l_tokens ])
+		self.precision = nn.ModuleList([ MulticlassPrecision(
+			n_tokens,
+			top_k=top_k,
+			average=average,
+			multidim_average=multidim_average,
+			ignore_index=ignore_index,
+		) for n_tokens in l_tokens ])
+
+	def calc_accuracy( self, inputs, targets, quant_levels ):
+		return sum( [ self.accuracy[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+
+	def calc_precision( self, inputs, targets, quant_levels ):
+		return sum( [ self.precision[l]( input, target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+
+	def __call__(self, *args, **kwargs):
+		return dict(
+			acc=self.calc_accuracy(*args, **kwargs),
+		)
+
 class Base(nn.Module):
 	# to-do: clean up this property mess
 
@@ -281,6 +329,9 @@ class Base(nn.Module):
 		n_prom_tokens = n_audio_tokens
 		n_resp_tokens = n_audio_tokens + self.causal_size
 
+		audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
+		split_classifiers = self.config.split_classifiers if self.config is not None else True
+
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
 		self.tones_emb = None
@@ -306,11 +357,11 @@ class Base(nn.Module):
 		else:
 			self.proms_emb = AudioEmbedding(
 				[n_prom_tokens] * self.n_prom_levels, d_model,
-				sums=self.config.audio_embedding_sums if self.config is not None else True
+				sums=audio_embedding_sums,
 			)
 			self.resps_emb = AudioEmbedding(
 				[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
-				sums=self.config.audio_embedding_sums if self.config is not None else True
+				sums=audio_embedding_sums,
 			)
 
 		# useless since I actually removed using these with the input processing overhaul...
@@ -533,29 +584,64 @@ class Base(nn.Module):
 				#initializer_cfg=initializer_cfg,
 			)
 			self.model.gradient_checkpointing = self.gradient_checkpointing
+		elif self.arch_type == "mmfreelm":
+			self.model = HGRNBitModel(HGRNBitConfig(
+				vocab_size=n_resp_tokens,
+				hidden_size=d_model,
+				max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
+				intermediate_size=d_model*4,
+				num_hidden_layers=n_layers,
+				num_heads=n_heads,
+				#hidden_act="gelu",
+				#is_encoder_decoder=False,
+				#is_decoder=True,
+				attn_mode=hf_attention,
+				#gradient_checkpointing=self.gradient_checkpointing,
+			))
+
+			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+					use_reentrant=False
+				))
+
+			#if training:
+			#	self.model.training = True
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
 		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
 			self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
 
-		self.classifier = nn.Linear(d_model, n_resp_tokens)
+		if not split_classifiers:
+			self.classifier = nn.Linear(d_model, n_resp_tokens)
+			self.classifiers = None
+
+			self.accuracy_metric = MulticlassAccuracy(
+				n_resp_tokens,
+				top_k=10,
+				average="micro",
+				multidim_average="global",
+				ignore_index=self.ignore_index,
+			)
 
-		self.accuracy_metric = MulticlassAccuracy(
-			n_resp_tokens,
-			top_k=10,
-			average="micro",
-			multidim_average="global",
-			ignore_index=self.ignore_index,
-		)
+			self.precision_metric = MulticlassPrecision(
+				n_resp_tokens,
+				top_k=10,
+				average="micro",
+				multidim_average="global",
+				ignore_index=self.ignore_index,
+			)
+
+			self.metrics = None
+		else:
+			levels = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+
+			self.classifier = None
+			self.classifiers = AudioClassifier( levels, d_model )
+			self.accuracy_metric = None
+			self.precision_metric = None
+			self.metrics = Metrics( levels )
 
-		self.precision_metric = MulticlassPrecision(
-			n_resp_tokens,
-			top_k=10,
-			average="micro",
-			multidim_average="global",
-			ignore_index=self.ignore_index,
-		)
 
 	def _forward(
 		self,
@@ -623,9 +709,17 @@ class Base(nn.Module):
 			x = self.model( hidden_states=x )
 		elif self.arch_type == "bitnet":
 			x = self.model(x)
+		elif self.arch_type == "mmfreelm":
+			x = self.model(
+				attention_mask=m,
+				inputs_embeds=x,
+			)
+
+			x = x[0]
 
 		# output projection layer with masking
-		x = self.classifier(x) * mask
+		if self.classifier is not None:
+			x = self.classifier(x) * mask
 
 		return x, state, aux_loss
 
@@ -803,7 +897,7 @@ class Base(nn.Module):
 					# "nll" was in the original implementation and should actually just be called something else
 					nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
 				)
-				self.stats = dict(
+				self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
 					acc = self.accuracy_metric( inputs, target ),
 					# precision = self.precision_metric( inputs, target ),
 				)
@@ -811,7 +905,7 @@ class Base(nn.Module):
 				self.loss = dict(
 					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
 				)
-				self.stats = dict(
+				self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
 					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 				)
 
@@ -887,7 +981,11 @@ class Base(nn.Module):
 				# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
 				else:
 					self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
-					self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
+					if self.metrics is not None:
+						metrics = self.metrics( batch["logits"], batch["targets"], quant_levels )
+						self.stats["acc"][name] = metrics["acc"]
+					else:
+						self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
 
 	def forward(
 		self,
@@ -896,7 +994,6 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 		state: dict | list | None = None,
 	):
-
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 		x, m = list_to_tensor(x_list)
 
@@ -912,6 +1009,10 @@ class Base(nn.Module):
 
 		device = x.device
 		batch_size = len(x_list)
+
+		# pure AR
+		if quant_levels is None:
+			quant_levels = [ 0 for _ in range(batch_size) ]
 
 		# pad our input and mask, but retain the original length by doing it after
 		if self.l_padding and x.shape[1] % self.l_padding != 0:
@@ -934,6 +1035,9 @@ class Base(nn.Module):
 			state=state,
 		)
 
+		if self.classifiers is not None:
+			x = self.classifiers(x, levels = quant_levels) * m
+
 		# Remove padding
 		logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
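
For reference, a minimal standalone sketch of the split-classifier idea this patch adds: one linear head per RVQ level, picked per batch item by its quant level, so the AR level can keep its extra stop token while the NAR levels use a smaller vocabulary. The sizes and names below (d_model, levels, the random inputs) are illustrative placeholders rather than values from the repo's configs, and the per-item logits are kept as a list here instead of stacked, since the vocabulary size differs between levels.

# illustrative sketch only: one projection head per RVQ level, chosen per batch item
import torch
import torch.nn as nn

d_model = 1024
levels = [1025] + [1024] * 7  # e.g. the AR level keeps a stop token, the NAR levels do not

heads = nn.ModuleList([nn.Linear(d_model, n_tokens) for n_tokens in levels])

hidden = torch.randn(3, 50, d_model)  # [batch, seq_len, d_model] from the backbone
quant_levels = [0, 1, 2]              # per-item target RVQ level (0 = AR, 1+ = NAR)

# project each item with the head matching its level; the output vocab size depends
# on the level, so the results stay a list of [seq_len, n_tokens] tensors
logits = [ heads[l](h) for h, l in zip(hidden, quant_levels) ]
print([ tuple(x.shape) for x in logits ])  # [(50, 1025), (50, 1024), (50, 1024)]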