tweaks
parent 4afa4ccce5
commit 3f1070f575
@@ -112,6 +112,7 @@ The model even working at all might entirely be a fluke.
 A naive embedding implementation (`./vall_e/models/base.py`) manages to "just work" for EnCodec, while other audio codecs (DAC, `nvidia/audio-codec-44khz`) fail to converge meaningfully.
 
 A more codec-aware embedding/classifier implementation (`./vall_e/models/base_v2.py`) fails to properly learn all levels for any codec, even with all the additional cruft to help things. Even scaling the model up just has the gradients seem a little more chaotic with about the same training progression.
+* However it seems just giving it time will have things eventually sort itself out, maybe.
 
 ## Notices and Citations
 
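For context on the hunk above: a common way to "naively" embed residual codec tokens is to give each RVQ level its own embedding table and sum the per-level vectors into one vector per frame. A minimal sketch of that idea, with toy names and shapes (illustrative only, not the actual `./vall_e/models/base.py` implementation):

```python
import torch
import torch.nn as nn

class SummedCodeEmbedding(nn.Module):
    """One embedding table per RVQ level, summed into a single vector per frame."""
    def __init__(self, n_levels: int, n_tokens: int, dim: int):
        super().__init__()
        self.embs = nn.ModuleList([ nn.Embedding(n_tokens, dim) for _ in range(n_levels) ])

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: ( seq_len, n_levels ) token ids => ( seq_len, dim )
        return torch.stack([ emb(codes[:, i]) for i, emb in enumerate(self.embs) ], dim=1).sum(dim=1)
```

The `base_v2.py` hunks further down in this commit work on the more structured alternative: per-level embeddings with attention across levels instead of a flat sum.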
@@ -751,7 +751,10 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 	metadata = {}
 
 	if cfg.dataset.use_metadata and metadata_path.exists():
-		metadata = json_read( metadata_path )
+		try:
+			metadata = json_read( metadata_path )
+		except Exception as e:
+			return {}
 
 	if len(metadata) == 0:
 		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
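The hunk above makes the metadata read fail soft: a malformed metadata file now yields an empty dict, so the caller falls through to scanning the data directory instead of raising. A standalone sketch of the same pattern, using plain `json`/`pathlib` in place of the repo's `json_read` and `cfg` helpers (those substitutions are mine):

```python
import json
from pathlib import Path

def load_metadata(metadata_path: Path) -> dict:
    # Missing or unparsable metadata should not abort dataset loading;
    # an empty dict signals "fall back to walking the data directory".
    if not metadata_path.exists():
        return {}
    try:
        return json.loads(metadata_path.read_text())
    except Exception:
        return {}
```

Catching a bare `Exception` mirrors the diff; narrowing it to `json.JSONDecodeError` / `OSError` would be stricter.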
@@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
 from .arch import *
-from ..utils import ml, clamp, mean
+from ..utils import ml, clamp, mean, logit_normalization
 from ..samplers import *
 
 # yuck, kind of needed
@@ -107,7 +107,8 @@ class ResidualAudioEncoder(nn.Module):
 		# ( seq_len, dim ) => ( seq_len, levels, dim )
 		x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
 		x = x + self.pos_embedding
-		x, _ = self.cross_attn( x, x, x )
+		attn, _ = self.cross_attn( x, x, x )
+		x = x + attn
 		x = self.proj( x.mean(dim=1) )
 
 		return x
@@ -135,7 +136,8 @@ class ResidualAudioDecoder(nn.Module):
 	def _forward( self, x: Tensor ) -> Tensor:
 		seq_len, resp_levels = x.shape[0], len(self.projs)
 		x = torch.stack([proj(x) for proj in self.projs], dim=1)
-		x, _ = self.cross_attn( x, x, x )
+		attn, _ = self.cross_attn( x, x, x )
+		x = x + attn
 		x = self.head( x )
 		x = x.view( resp_levels, seq_len, -1 )
 		return x
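Both hunks above make the same fix in the encoder and decoder: the attention output used to overwrite `x`, and now it is added back as a residual. A minimal sketch of the pattern with `torch.nn.MultiheadAttention`, using made-up dimensions:

```python
import torch
import torch.nn as nn

dim, n_levels, seq_len = 256, 8, 12
attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

# treat each frame as a "batch" entry and attend across its RVQ levels,
# mirroring the ( seq_len, levels, dim ) layout in the encoder
x = torch.randn(seq_len, n_levels, dim)

# before: x, _ = attn(x, x, x)   # features replaced by the attention output
attn_out, _ = attn(x, x, x)
x = x + attn_out                 # after: residual connection preserves the input path
```

Keeping the identity path is the usual motivation: gradients and the original per-level embeddings still flow through even when the attention output is small.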
@@ -865,10 +867,6 @@ class Base_V2(nn.Module):
 
 			return input
 
-		def _logit_normalization( logit ):
-			norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + 1e-7
-			return torch.div(logit, norms) / self.logit_normalization
-
 		def _calc_loss( logit, sequence, causal = True, level = None ):
 			# filter tokens that exceed the vocab size
 			sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
@@ -888,10 +886,10 @@ class Base_V2(nn.Module):
 			if self.logit_normalization:
 				# it would probably be better to unsqueeze then squeeze to avoid code duplication but who cares
 				if not batched:
-					logit = _logit_normalization( logit )
+					logit = logit_normalization( logit, self.logit_normalization )
 				else:
 					for i, l in enumerate( logit ):
-						logit[i] = _logit_normalization( l )
+						logit[i] = logit_normalization( l, self.logit_normalization )
 
 			# flatten batch
 			if batched:
@@ -18,4 +18,5 @@ from .utils import (
 	convert_kwargs,
 	coerce_dtype,
 	mean,
+	logit_normalization,
 )
@@ -37,6 +37,10 @@ def mean( l ):
 		return 0
 	return sum(l) / len(l)
 
+def logit_normalization( logit, factor=1, eps=1.0e-7 ):
+	norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + eps
+	return torch.div(logit, norms) / factor
+
 # removes prefix from key in a dict
 # useful for mapping args like ar_temperature => temperature
 def convert_kwargs( kwargs, prefix ):
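The remaining hunks hoist the inner `_logit_normalization` closure out of `base_v2.py` into this shared `logit_normalization` utility (and export it), so the loss code divides logits by their L2 norm along the class dimension and then by a temperature-like `factor`. A small usage sketch; the cross-entropy call and the `factor=0.1` value are illustrative, not taken from the repo:

```python
import torch
import torch.nn.functional as F

def logit_normalization(logit, factor=1, eps=1.0e-7):
    # L2-normalize along the class dimension, then scale by 1 / factor
    norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + eps
    return torch.div(logit, norms) / factor

logits  = torch.randn(4, 10, 1024)           # ( batch, seq_len, n_classes )
targets = torch.randint(0, 1024, (4, 10))    # ( batch, seq_len )

normed = logit_normalization(logits, factor=0.1)
loss = F.cross_entropy(normed.transpose(1, 2), targets)  # classes go in dim 1 for F.cross_entropy
```

Bounding the logit magnitude this way tempers over-confident predictions; `factor` plays the role of a temperature, with smaller values yielding sharper normalized logits.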