From 341e19162b0c91e8e067e678f7432f5be19f67dd Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 6 Sep 2024 11:41:41 -0500 Subject: [PATCH] fixes, again --- vall_e/data.py | 2 -- vall_e/models/ar_nar.py | 5 ++--- vall_e/models/base.py | 6 +++--- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index fa4c5de..8680619 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -1055,8 +1055,6 @@ class Dataset(_Dataset): # Base STT ( => ) elif task == "stt": - # easier to just keep it instead of wrangling around trying to remove it - # it might also help to provide a guidance prompt but who knows right now proms = [ task ] diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 574bc42..014e677 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -135,9 +135,8 @@ class AR_NAR(Base): for j, prom in enumerate( proms ): if not isinstance( prom, torch.Tensor ): continue - - if quant_level >= prom.shape[-1]: - quant_levels[i] = prom.shape[-1] - 1 + if quant_level >= prom.shape[-1]: + quant_levels[i] = prom.shape[-1] - 1 # apply token dropout error compensation if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]): diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 1607af8..e8dd287 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -280,7 +280,7 @@ class Classifiers(nn.Module): xi = [ #x if l == 0 else x if x.shape[-1] == max_size else - torch.cat( [ x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 ) + torch.cat( [x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 ) for x, l in zip(xi, levels) ] return torch.stack( xi ) @@ -1057,7 +1057,7 @@ class Base(nn.Module): embedding = self.langs_emb( input ) elif name == "prom": proms = [ input ] if isinstance(input, torch.Tensor) else input - input_prom = torch.cat([ prom for prom in proms if isinstance(input, torch.Tensor) ]) + input_prom = torch.cat([ prom for prom in proms if isinstance(prom, torch.Tensor) ]) embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] ) elif name == "tone" and self.tones_emb is not None: @@ -1164,7 +1164,7 @@ class Base(nn.Module): # list of tokens if not isinstance(input, torch.Tensor): - return sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] ) + 1 + return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1 # interleaved model if self.interleave and name == "resp":