From 3f1070f575d2dd5ffc8df675fc892d64f3bd8509 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 2 Mar 2025 22:36:25 -0600
Subject: [PATCH] tweaks

---
 docs/README.md           |  1 +
 vall_e/data.py           |  5 ++++-
 vall_e/models/base_v2.py | 16 +++++++---------
 vall_e/utils/__init__.py |  1 +
 vall_e/utils/utils.py    |  4 ++++
 5 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/docs/README.md b/docs/README.md
index 4d3e8f0..f96d898 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -112,6 +112,7 @@ The model even working at all might entirely be a fluke.
 A naive embedding implementation (`./vall_e/models/base.py`) manages to "just work" for EnCodec, while other audio codecs (DAC, `nvidia/audio-codec-44khz`) fail to converge meaningfully.
 A more codec-aware embedding/classifier implementation (`./vall_e/models/base_v2.py`) fails to properly learn all levels for any codec, even with all the additional cruft to help things.
 Even scaling the model up just has the gradients seem a little more chaotic with about the same training progression.
+* However, it seems that just giving it time will eventually have things sort themselves out, maybe.
 
 ## Notices and Citations
 
diff --git a/vall_e/data.py b/vall_e/data.py
index f4f1787..81ae375 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -751,7 +751,10 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 
 	metadata = {}
 	if cfg.dataset.use_metadata and metadata_path.exists():
-		metadata = json_read( metadata_path )
+		try:
+			metadata = json_read( metadata_path )
+		except Exception as e:
+			return {}
 
 	if len(metadata) == 0:
 		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 5b15ba8..27fb571 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
 from .arch import *
-from ..utils import ml, clamp, mean
+from ..utils import ml, clamp, mean, logit_normalization
 from ..samplers import *
 
 # yuck, kind of needed
@@ -107,7 +107,8 @@ class ResidualAudioEncoder(nn.Module):
 		# ( seq_len, dim ) => ( seq_len, levels, dim )
 		x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
 		x = x + self.pos_embedding
-		x, _ = self.cross_attn( x, x, x )
+		attn, _ = self.cross_attn( x, x, x )
+		x = x + attn
 		x = self.proj( x.mean(dim=1) )
 
 		return x
@@ -135,7 +136,8 @@ class ResidualAudioDecoder(nn.Module):
 	def _forward( self, x: Tensor ) -> Tensor:
 		seq_len, resp_levels = x.shape[0], len(self.projs)
 		x = torch.stack([proj(x) for proj in self.projs], dim=1)
-		x, _ = self.cross_attn( x, x, x )
+		attn, _ = self.cross_attn( x, x, x )
+		x = x + attn
 		x = self.head( x )
 		x = x.view( resp_levels, seq_len, -1 )
 		return x
@@ -865,10 +867,6 @@ class Base_V2(nn.Module):
 
 			return input
 
-		def _logit_normalization( logit ):
-			norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + 1e-7
-			return torch.div(logit, norms) / self.logit_normalization
-
 		def _calc_loss( logit, sequence, causal = True, level = None ):
 			# filter tokens that exceed the vocab size
 			sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
@@ -888,10 +886,10 @@ class Base_V2(nn.Module):
 			if self.logit_normalization:
 				# it would probably be better to unsqueeze then squeeze to avoid code duplication but who cares
 				if not batched:
-					logit = _logit_normalization( logit )
+					logit = logit_normalization( logit, self.logit_normalization )
 				else:
 					for i, l in enumerate( logit ):
-						logit[i] = _logit_normalization( l )
+						logit[i] = logit_normalization( l, self.logit_normalization )
 
 			# flatten batch
 			if batched:
diff --git a/vall_e/utils/__init__.py b/vall_e/utils/__init__.py
index e6c03a9..b97a03f 100755
--- a/vall_e/utils/__init__.py
+++ b/vall_e/utils/__init__.py
@@ -18,4 +18,5 @@ from .utils import (
 	convert_kwargs,
 	coerce_dtype,
 	mean,
+	logit_normalization,
 )
\ No newline at end of file
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index 088958e..c8eb32c 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -37,6 +37,10 @@ def mean( l ):
 		return 0
 	return sum(l) / len(l)
 
+def logit_normalization( logit, factor=1, eps=1.0e-7 ):
+	norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + eps
+	return torch.div(logit, norms) / factor
+
 # removes prefix from key in a dict
 # useful for mapping args like ar_temperature => temperature
 def convert_kwargs( kwargs, prefix ):
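
For reference, the ResidualAudioEncoder/ResidualAudioDecoder change above swaps an in-place self-attention for a residual connection (x = x + attn). A minimal sketch of that pattern in isolation, using a stock nn.MultiheadAttention with batch_first=True and made-up sizes (the actual modules' dims, head count, and attention configuration come from the model config and may differ):

    import torch
    from torch import nn

    # hypothetical sizes, standing in for the model's config
    dim, levels, seq_len = 256, 8, 12
    cross_attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

    x = torch.randn(seq_len, levels, dim)  # ( seq_len, levels, dim ), as in the encoder
    attn, _ = cross_attn(x, x, x)          # mixes information across levels at each position
    x = x + attn                           # residual: the original per-level embeddings survive

Before the patch, x was simply overwritten by the attention output; the residual form keeps a direct path from the per-level embeddings (and positional embedding) to the output, which presumably helps the gradients the README hunk complains about.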
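
Likewise, a minimal sketch of the logit_normalization helper extracted into vall_e/utils/utils.py, showing its effect on hypothetical logits (factor=2 stands in for self.logit_normalization; the shapes are invented):

    import torch

    def logit_normalization(logit, factor=1, eps=1.0e-7):
        # L2-normalize along the vocab dim, then scale down by the factor
        norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + eps
        return torch.div(logit, norms) / factor

    logits = torch.randn(4, 1024)               # ( seq_len, n_vocab ), hypothetical sizes
    normed = logit_normalization(logits, 2)
    print(torch.norm(normed, p=2, dim=-1))      # each row now has L2 norm ~= 1/factor = 0.5

Pinning every step's logits to a fixed norm before cross-entropy caps how peaked any single prediction can get, in the spirit of logit normalization for taming overconfidence, which seems to be why it's gated on self.logit_normalization being non-zero.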