diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 0b38be9..8a7fbbc 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -66,16 +66,11 @@ class AR_NAR(Base):
 		return cfg.model.tones
 
 	@property
-	def recurrent_chunk_size(self) -> int:
-		return 0
-
-	"""
-	@property
-	def rotary_embedding_base(self) -> float:
-		if hasattr(self, "config") and self.config:
-			return self.config.rotary_embedding_base
-		return cfg.model.rotary_embedding_base
-	"""
+	def causal_size(self) -> int:
+		# 1 for the stop token
+		# governs how much to shift the logits by
+		# could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
+		return 1 if self.causal else 0
 
 	@property
 	def interleave(self) -> bool:
@@ -241,7 +236,7 @@ class AR_NAR(Base):
 			max_steps *= self.n_prom_levels
 
 		# get next in sequence
-		for n in trange(max_steps // max(1, self.recurrent_chunk_size), desc="AR"):
+		for n in trange(max_steps // max(1, self.causal_size), desc="AR"):
 			# experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
 			# UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
 			if max_resp_context > 0:
@@ -463,9 +458,11 @@ def example_usage():
 
 	engine = Engine(model=model, optimizer=optimizer)
 
+	"""
 	torch.save( {
 		'module': model.state_dict()
 	}, f"./data/{cfg.model.arch_type}.pth" )
+	"""
 
 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
@@ -498,9 +495,11 @@ def example_usage():
 			tqdm.write(f"{stats}")
 
+		"""
 		torch.save( {
 			'module': model.state_dict()
 		}, f"./data/{cfg.model.arch_type}.pth" )
+		"""
 
 	#sample("init", 5)
 
 	train()
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 43a794d..b1dec02 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -202,13 +202,9 @@ class Base(nn.Module):
 		raise NotImplementedError
 
 	@property
-	def recurrent_chunk_size(self) -> int:
+	def causal_size(self) -> int:
 		raise NotImplementedError
 
-	@property
-	def rotary_embedding_base(self) -> float:
-		return 10000
-
 	@property
 	def interleave(self) -> bool:
 		return False
@@ -271,7 +267,7 @@ class Base(nn.Module):
 
 		# +1 to include the stop token
 		n_prom_tokens = n_audio_tokens
-		n_resp_tokens = n_audio_tokens + 1 # (1 if self.causal else 0) interoperability
+		n_resp_tokens = n_audio_tokens + self.causal_size
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
@@ -456,12 +452,12 @@ class Base(nn.Module):
 				use_biases=self.version < 3,
 				use_glu=self.version >= 3,
 
-				chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
-				recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
+				chunkwise_recurrent=self.causal and self.causal_size > 0,
+				recurrent_chunkwise_size=self.causal_size if self.causal else 0,
 				no_output_layer=True,
 				decoder_normalize_before=True,
 
-				rotary_embedding_base=self.rotary_embedding_base, # 10000
+				rotary_embedding_base=10000
 			)
 
 			if n_experts > 1:
@@ -486,7 +482,7 @@ class Base(nn.Module):
 				activation_fn="gelu",
 				use_glu=False, # self.version >= 3,
 
-				recurrent_chunk_size=self.recurrent_chunk_size if self.causal else 0,
+				recurrent_chunk_size=self.causal_size if self.causal else 0,
 				decoder_normalize_before=True,
 
 				deepnorm=False,
@@ -710,8 +706,9 @@ class Base(nn.Module):
 			if quant_levels is not None and quant_levels[i] > 0:
 				continue
 
-			logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
-			target_list[i] = target_list[i][..., 1:] # predicts token n + 1
+			l = self.causal_size
+			logits[i] = logits[i][..., :-l, :] # shift the target so that token n...
+			target_list[i] = target_list[i][..., l:] # predicts token n + 1
 
 		# see comments for the split-loss calc cross_entropy call
 		if False:
@@ -769,8 +766,9 @@ class Base(nn.Module):
 				# for the AR, shift sequence so that it predicts the next token
 				# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
 				if quant_level is None or quant_level == 0:
-					logit = logit[..., :-1, :] # get all but the final logit
-					input = input[..., 1:] # shift sequence to the right by one
+					l = self.causal_size
+					logit = logit[..., :-l, :]
+					input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
 
 				if name not in info:
 					info[name] = {
@@ -802,15 +800,6 @@ class Base(nn.Module):
 			else:
 				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
 				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
-
-		# accuracy sometimes breaks for mamba
-
-		# to-do: compute loss per individual batch to scale per RVQ level
-		"""
-		rvq_loss_factor = self.loss_factor("quant")
-		if isinstance( rvq_loss_factor, list ):
-			...
-		"""
 
 	def forward(
 		self,
@@ -898,12 +887,8 @@ class Base(nn.Module):
 		if quant_levels is not None:
 			logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
 		# (AR chunkwise) return the last chunkwise piece
-		elif self.causal and self.recurrent_chunk_size > 0:
-			logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
-		# (AR) return just the last code
-		# Recurrent decoding relies on the last token in the logits, because each token predicts the next token in the sequence (obviously)
-		else:
-			logits = [ logit[-1:] for logit in logits ]
+		elif self.causal:
+			logits = [ logit[-self.causal_size:] for logit in logits ]
 
 		devices = [ logit.device for logit in logits ]
 		logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
index fb99546..d2a8dbd 100644
--- a/vall_e/models/experimental.py
+++ b/vall_e/models/experimental.py
@@ -315,9 +315,11 @@ def example_usage():
 
 	engine = Engine(model=model, optimizer=optimizer)
 
+	"""
 	torch.save( {
 		'module': model.state_dict()
 	}, f"./data/{cfg.model.arch_type}.pth" )
+	"""
 
 	print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
@@ -400,9 +402,11 @@ def example_usage():
 			tqdm.write(f"{stats}")
 
+		"""
 		torch.save( {
 			'module': model.state_dict()
 		}, f"./data/{cfg.model.arch_type}.pth" )
+		"""
 
 	#sample("init", 5)
 
 	train()
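
For review, a minimal standalone sketch (not part of the patch) of what shifting by `causal_size` amounts to, both in the loss calculation and when trimming logits at decode time. The tensor shapes and the `causal_size = 1` value below are illustrative only, matching the AR case in `AR_NAR.causal_size` above; only `torch` is assumed.

	import torch
	import torch.nn.functional as F

	causal_size = 1                          # matches `return 1 if self.causal else 0`
	T, V = 6, 1025                           # sequence length; n_audio_tokens + causal_size classes
	logits = torch.randn(T, V)               # per-position predictions
	targets = torch.randint(0, V, (T,))      # ground-truth token ids

	# Training: position n predicts token n + causal_size, so drop the last
	# `causal_size` predictions and the first `causal_size` targets before the loss.
	loss = F.cross_entropy(logits[:-causal_size, :], targets[causal_size:])

	# Decoding: a causal model only samples from its newest predictions, so keep
	# the trailing `causal_size` logits (generalizing the old `logit[-1:]`).
	next_tokens = logits[-causal_size:, :].argmax(dim=-1)   # greedy pick, for illustration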