mrq 2025-03-02 22:36:25 -06:00
parent 4afa4ccce5
commit 3f1070f575
5 changed files with 17 additions and 10 deletions

View File

@@ -112,6 +112,7 @@ The model even working at all might entirely be a fluke.
 A naive embedding implementation (`./vall_e/models/base.py`) manages to "just work" for EnCodec, while other audio codecs (DAC, `nvidia/audio-codec-44khz`) fail to converge meaningfully.
 A more codec-aware embedding/classifier implementation (`./vall_e/models/base_v2.py`) fails to properly learn all levels for any codec, even with all the additional cruft added to help things along. Even scaling the model up only makes the gradients a little more chaotic, with about the same training progression.
+* However, it seems that simply giving it more training time may eventually sort things out.
 ## Notices and Citations
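
For illustration, a minimal sketch of what a "naive" embedding over residual codec levels can look like: one `nn.Embedding` per quantizer level, summed into a single token embedding. The class name, dimensions, and codebook size below are assumptions for the sketch, not the actual modules in `./vall_e/models/base.py`.

```python
import torch
import torch.nn as nn

class NaiveAudioEmbedding(nn.Module):
    # Illustrative only: embed each residual level's code independently, then sum.
    # This is the "just works for EnCodec" style of embedding described above;
    # it treats every level the same way and encodes nothing codec-specific.
    def __init__(self, n_levels: int = 8, codebook_size: int = 1024, dim: int = 1024):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(codebook_size, dim) for _ in range(n_levels)])

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: ( seq_len, n_levels ) of integer codebook indices
        return sum(emb(codes[:, i]) for i, emb in enumerate(self.embs))

if __name__ == "__main__":
    codes = torch.randint(0, 1024, (50, 8))  # 50 frames, 8 RVQ levels
    emb = NaiveAudioEmbedding()(codes)
    print(emb.shape)  # torch.Size([50, 1024])
```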

View File

@@ -751,7 +751,10 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
     metadata = {}
     if cfg.dataset.use_metadata and metadata_path.exists():
+        try:
+            metadata = json_read( metadata_path )
+        except Exception as e:
+            return {}
     if len(metadata) == 0:
         return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
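
The hunk above wraps the metadata read so a missing or corrupt JSON file degrades to the no-metadata path instead of raising. A minimal standalone sketch of that fallback pattern, using a hypothetical `read_metadata` helper and plain `json` rather than the repo's `json_read`:

```python
import json
from pathlib import Path

def read_metadata(path: Path) -> dict:
    # Return parsed metadata, or an empty dict if the file is missing or corrupt,
    # so the caller can fall back to scanning the data directory instead.
    if not path.exists():
        return {}
    try:
        return json.loads(path.read_text())
    except Exception:
        return {}

if __name__ == "__main__":
    print(read_metadata(Path("does_not_exist.json")))  # {}
```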

View File

@@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 from .arch import *
-from ..utils import ml, clamp, mean
+from ..utils import ml, clamp, mean, logit_normalization
 from ..samplers import *
 # yuck, kind of needed
@@ -107,7 +107,8 @@ class ResidualAudioEncoder(nn.Module):
         # ( seq_len, dim ) => ( seq_len, levels, dim )
         x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
         x = x + self.pos_embedding
-        x, _ = self.cross_attn( x, x, x )
+        attn, _ = self.cross_attn( x, x, x )
+        x = x + attn
         x = self.proj( x.mean(dim=1) )
         return x
@@ -135,7 +136,8 @@ class ResidualAudioDecoder(nn.Module):
     def _forward( self, x: Tensor ) -> Tensor:
         seq_len, resp_levels = x.shape[0], len(self.projs)
         x = torch.stack([proj(x) for proj in self.projs], dim=1)
-        x, _ = self.cross_attn( x, x, x )
+        attn, _ = self.cross_attn( x, x, x )
+        x = x + attn
         x = self.head( x )
         x = x.view( resp_levels, seq_len, -1 )
         return x
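
The two hunks above stop overwriting `x` with the attention output and instead add it back as a residual (`x = x + attn`), so the stacked per-level embeddings are preserved alongside whatever the attention mixes in. A minimal, self-contained sketch of that pattern with assumed dimensions (not the repo's actual encoder/decoder):

```python
import torch
import torch.nn as nn

class ResidualSelfAttention(nn.Module):
    # Hypothetical illustration: the attention output is added to its input
    # rather than replacing it, keeping the original per-level signal intact.
    def __init__(self, dim: int = 256, heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attn(x, x, x)
        return x + attn_out

if __name__ == "__main__":
    x = torch.randn(1, 8, 256)  # ( batch, levels, dim )
    y = ResidualSelfAttention()(x)
    print(y.shape)  # torch.Size([1, 8, 256])
```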
@@ -865,10 +867,6 @@ class Base_V2(nn.Module):
             return input
-        def _logit_normalization( logit ):
-            norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + 1e-7
-            return torch.div(logit, norms) / self.logit_normalization
         def _calc_loss( logit, sequence, causal = True, level = None ):
             # filter tokens that exceed the vocab size
             sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
@@ -888,10 +886,10 @@ class Base_V2(nn.Module):
             if self.logit_normalization:
                 # it would probably be better to unsqueeze then squeeze to avoid code duplication but who cares
                 if not batched:
-                    logit = _logit_normalization( logit )
+                    logit = logit_normalization( logit, self.logit_normalization )
                 else:
                     for i, l in enumerate( logit ):
-                        logit[i] = _logit_normalization( l )
+                        logit[i] = logit_normalization( l, self.logit_normalization )
             # flatten batch
             if batched:

View File

@@ -18,4 +18,5 @@ from .utils import (
     convert_kwargs,
     coerce_dtype,
     mean,
+    logit_normalization,
 )

View File

@@ -37,6 +37,10 @@ def mean( l ):
         return 0
     return sum(l) / len(l)
+def logit_normalization( logit, factor=1, eps=1.0e-7 ):
+    norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + eps
+    return torch.div(logit, norms) / factor
 # removes prefix from key in a dict
 # useful for mapping args like ar_temperature => temperature
 def convert_kwargs( kwargs, prefix ):
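
For reference, the relocated helper together with a small usage sketch; the tensor shape and `factor` value below are illustrative assumptions, with `factor` playing the role of `self.logit_normalization` in the hunks above:

```python
import torch

def logit_normalization( logit, factor=1, eps=1.0e-7 ):
    # L2-normalize the logits along the class dimension, then scale by 1/factor.
    norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + eps
    return torch.div(logit, norms) / factor

if __name__ == "__main__":
    logits = torch.randn(4, 1024)  # ( seq_len, vocab ), shape assumed for the sketch
    normed = logit_normalization(logits, factor=2.5)
    # each row's L2 norm is now approximately 1 / factor
    print(torch.linalg.norm(normed, dim=-1))
```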