diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index fbf6426..cd41968 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -281,6 +281,27 @@ class AudioEmbedding(nn.Module):
 
 		return x
 
+# time-step embedding
+# for the NAR-len, since it most likely requires encoding the timestep
+class TimeEmbedding(nn.Module):
+	def __init__(
+		self,
+		d_model
+	):
+		super().__init__()
+		self.emb = SinusoidalEmbedding(d_model)
+		self.mlp = nn.Sequential(
+			nn.Linear(d_model, d_model*4),
+			nn.SiLU(),
+			nn.Linear(d_model*4, d_model),
+		)
+
+	def forward( self, t ):
+		t = self.emb(t)
+		t = self.mlp(t)
+
+		return t
+
 # per-level classification
 # it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
 class Classifiers(nn.Module):
@@ -545,6 +566,7 @@ class Base(nn.Module):
 
 		# experimental NAR-only mode
 		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
+		self.time_emb = TimeEmbedding(d_model) if "len" in self.capabilities else None
 
 		if attention_backend == "auto":
 			attention_backend = "sdpa"
@@ -967,6 +989,7 @@ class Base(nn.Module):
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 		task_list: list[str] | None = None,
+		time_list: list[Tensor] | None = None,
 
 		quant_levels: int | list[int] | Tensor | None = None
 	):
@@ -1012,6 +1035,8 @@ class Base(nn.Module):
 					t = random.random()
 					p = math.cos(t * math.pi * 0.5)
 					dropout_mask = _dropout_mask( resps_list[i], p=p )
+
+					inputs[i].append( ("timestep", torch.tensor(t, device=device) ) )
 					inputs[i].append( ("dropout_mask", dropout_mask ) )
 
 			# Audio length prediction task
@@ -1108,16 +1133,14 @@ class Base(nn.Module):
 			task_type = "tts"
 			input_prom = None
 			dropout_mask = None
+			timestep = None
 
 			# pre-iterate
 			for name, input in batch_input:
-				"""
-				if name == "prop":
-					proms = [ input ] if isinstance(input, torch.Tensor) else input
-					input_prom = torch.cat([ prom for prom in proms if isinstance(prom, torch.Tensor) ])
-				"""
 				if name == "dropout_mask":
 					dropout_mask = input
+				elif name == "timestep":
+					timestep = input
 
 			for name, input in batch_input:
 				# technically can provide a map for input_name => embedding, but some embedding requires additional processing
@@ -1160,6 +1183,9 @@ class Base(nn.Module):
 						offset = 0,
 						quant_level = 0,
 					)
+
+					t_emb = self.time_emb( timestep )
+					embedding += t_emb
 				# cheat-y way to handle performing STT across all levels
 				elif task_type in summed_embeddings_task:
 					# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
@@ -1236,7 +1262,7 @@ class Base(nn.Module):
 				return 1
 
 			# a mask
-			if name == "dropout_mask":
+			if name in ["dropout_mask", "timestep"]:
 				return 0
 
 			# list of tokens
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 30a3702..67700c2 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -298,6 +298,7 @@ class NAR(Base):
 				resps_list=resps_list,
 				lang_list=lang_list,
 				tone_list=tone_list,
+				time_list=[ timestep ],
 				quant_levels=quant_levels,
 			)
 			output = _super.forward(
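
Note for reviewers: the new TimeEmbedding wraps a SinusoidalEmbedding that already lives elsewhere in vall_e/models/base.py and is not part of this diff. As a reference point only, here is a minimal sketch of what such a module conventionally looks like, assuming the standard fixed log-spaced sine/cosine frequencies; the repo's actual implementation may differ:

import math
import torch
from torch import nn

class SinusoidalEmbedding(nn.Module):
	# maps a scalar timestep (e.g. t in [0, 1]) to a d_model-dim vector
	# using fixed, log-spaced sine/cosine frequencies (transformer/DDPM style)
	def __init__( self, d_model ):
		super().__init__()
		half = d_model // 2
		exponent = torch.arange(half) / max(half - 1, 1)
		self.register_buffer("omega", torch.exp(-math.log(10000.0) * exponent), persistent=False)

	def forward( self, t ):
		t = t.view(-1, 1).float()          # accepts 0-dim or [B] tensors
		angles = t * self.omega[None, :]   # [B, d_model//2]
		return torch.cat([angles.sin(), angles.cos()], dim=-1)  # [B, d_model]

Under this sketch, self.time_emb( timestep ) applied to the 0-dim tensor stored above yields a [1, d_model] vector, so the embedding += t_emb line broadcasts the timestep conditioning across the whole sequence (assuming embedding is [seq_len, d_model]).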
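
On the schedule behind the recorded timestep: t is drawn uniformly from [0, 1) and the mask probability is p = cos(t * pi / 2), which sweeps from p = 1 (everything dropped out) at t = 0 down to p = 0 at t = 1. This is the same cosine shape used by MaskGIT-style iterative decoding, which is presumably what conditioning on t is meant to support at inference time. A quick sanity check of just that formula:

import math

for t in (0.0, 0.25, 0.5, 0.75, 1.0):
	p = math.cos(t * math.pi * 0.5)
	print(f"t={t:.2f} -> mask probability p={p:.3f}")

# t=0.00 -> mask probability p=1.000
# t=0.25 -> mask probability p=0.924
# t=0.50 -> mask probability p=0.707
# t=0.75 -> mask probability p=0.383
# t=1.00 -> mask probability p=0.000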