tweaks
This commit is contained in:
parent 4afa4ccce5
commit 3f1070f575
@@ -112,6 +112,7 @@ The model even working at all might entirely be a fluke.
 
 A naive embedding implementation (`./vall_e/models/base.py`) manages to "just work" for EnCodec, while other audio codecs (DAC, `nvidia/audio-codec-44khz`) fail to converge meaningfully.
 
 A more codec-aware embedding/classifier implementation (`./vall_e/models/base_v2.py`) fails to properly learn all levels for any codec, even with all the additional cruft to help things. Even scaling the model up merely makes the gradients a little more chaotic, with about the same training progression.
+* However, it seems that just giving it time may eventually have things sort themselves out.
 
 ## Notices and Citations
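For context, the "naive" approach amounts to one embedding table per RVQ level, with the levels summed into a single model-dim vector — roughly like the sketch below (illustrative only; the class and argument names are assumptions, not the actual `base.py` code):

```python
import torch
from torch import nn

class NaiveAudioEmbedding(nn.Module):
    """One nn.Embedding per RVQ level; levels are summed into one model-dim vector."""
    def __init__(self, n_tokens: int, token_dim: int, n_levels: int):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: ( seq_len, levels ) => ( seq_len, dim ), summed across levels
        return sum(emb(codes[:, level]) for level, emb in enumerate(self.embs))
```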
@@ -751,7 +751,10 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
     metadata = {}
 
     if cfg.dataset.use_metadata and metadata_path.exists():
-        metadata = json_read( metadata_path )
+        try:
+            metadata = json_read( metadata_path )
+        except Exception as e:
+            return {}
 
     if len(metadata) == 0:
         return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
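The guarded read means a corrupt or unreadable metadata file now degrades to the "no metadata" path instead of raising. A minimal sketch of the pattern, assuming `json_read` is a thin wrapper around `json.load` (its actual definition is not shown in this diff):

```python
import json
from pathlib import Path

# Assumed shape of json_read: a thin json.load wrapper.
def json_read(path: Path) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def load_metadata(metadata_path: Path) -> dict:
    # Mirrors the guarded read above: a corrupt or unreadable file
    # yields an empty dict instead of an exception.
    try:
        return json_read(metadata_path)
    except Exception:
        return {}
```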
@@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
 from .arch import *
-from ..utils import ml, clamp, mean
+from ..utils import ml, clamp, mean, logit_normalization
 from ..samplers import *
 
 # yuck, kind of needed
@@ -107,7 +107,8 @@ class ResidualAudioEncoder(nn.Module):
         # ( seq_len, dim ) => ( seq_len, levels, dim )
         x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
         x = x + self.pos_embedding
-        x, _ = self.cross_attn( x, x, x )
+        attn, _ = self.cross_attn( x, x, x )
+        x = x + attn
         x = self.proj( x.mean(dim=1) )
 
         return x
@@ -135,7 +136,8 @@ class ResidualAudioDecoder(nn.Module):
     def _forward( self, x: Tensor ) -> Tensor:
         seq_len, resp_levels = x.shape[0], len(self.projs)
         x = torch.stack([proj(x) for proj in self.projs], dim=1)
-        x, _ = self.cross_attn( x, x, x )
+        attn, _ = self.cross_attn( x, x, x )
+        x = x + attn
         x = self.head( x )
         x = x.view( resp_levels, seq_len, -1 )
         return x
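Both the encoder and decoder hunks make the same change: the cross-attention output is added back onto `x` as a residual instead of replacing it, so each level's representation is refined rather than overwritten. A minimal sketch of the pattern, assuming `cross_attn` is an `nn.MultiheadAttention` over the levels axis (shapes are illustrative):

```python
import torch
from torch import nn

levels, dim = 8, 1024
cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=8, batch_first=True)

x = torch.randn(1, levels, dim)

# before: the attention output replaced x outright
# x, _ = cross_attn(x, x, x)

# after: the attention output is added back onto x (residual connection)
attn, _ = cross_attn(x, x, x)
x = x + attn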
@@ -865,10 +867,6 @@ class Base_V2(nn.Module):
 
             return input
 
-        def _logit_normalization( logit ):
-            norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + 1e-7
-            return torch.div(logit, norms) / self.logit_normalization
-
         def _calc_loss( logit, sequence, causal = True, level = None ):
             # filter tokens that exceed the vocab size
             sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
@@ -888,10 +886,10 @@ class Base_V2(nn.Module):
             if self.logit_normalization:
                 # it would probably be better to unsqueeze then squeeze to avoid code duplication but who cares
                 if not batched:
-                    logit = _logit_normalization( logit )
+                    logit = logit_normalization( logit, self.logit_normalization )
                 else:
                     for i, l in enumerate( logit ):
-                        logit[i] = _logit_normalization( l )
+                        logit[i] = logit_normalization( l, self.logit_normalization )
 
             # flatten batch
             if batched:
@@ -18,4 +18,5 @@ from .utils import (
     convert_kwargs,
     coerce_dtype,
     mean,
+    logit_normalization,
 )
@@ -37,6 +37,10 @@ def mean( l ):
         return 0
     return sum(l) / len(l)
 
+def logit_normalization( logit, factor=1, eps=1.0e-7 ):
+    norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + eps
+    return torch.div(logit, norms) / factor
+
 # removes prefix from key in a dict
 # useful for mapping args like ar_temperature => temperature
 def convert_kwargs( kwargs, prefix ):
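The hoisted helper just L2-normalizes the logits along the vocab axis and rescales by `factor`. A toy usage, normalizing logits before cross-entropy (the factor value here is illustrative, not a recommended setting):

```python
import torch
import torch.nn.functional as F
from vall_e.utils import logit_normalization  # the helper added above

# toy shapes: ( batch, vocab )
logits = torch.randn(4, 1024)
targets = torch.randint(0, 1024, (4,))

# L2-normalize along the vocab axis, then rescale; factor acts like an
# inverse temperature on the normalized logits
loss = F.cross_entropy(logit_normalization(logits, factor=0.1), targets)
```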