diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 9cf3796..533c678 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -223,7 +223,7 @@ class AudioEmbedding(nn.Module):
 
 		return embedding
 
-	def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor:
+	def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
 		if offset is None:
 			# prom
 			if self.capabilities is None:
@@ -236,6 +236,9 @@ class AudioEmbedding(nn.Module):
 			elif quant_level > 0:
 				offset = 1
 
+		if sums is None:
+			sums = self.sums
+
 		if quant_level is None:
 			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
 
@@ -247,8 +250,8 @@ class AudioEmbedding(nn.Module):
 
 		return x
 
-	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor:
-		x = self.internal_forward( xi, offset = offset, quant_level = quant_level ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
+	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
+		x = self.internal_forward( xi, offset = offset, quant_level = quant_level, sums = sums ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
 
 		if self.external_mode and xi.shape[0] > 0:
 			external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
@@ -1093,10 +1096,12 @@ class Base(nn.Module):
 				)
 			# cheat-y way to handle performing STT across all levels
 			elif task_type in summed_embeddings_task:
+				# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
 				embedding = sum([ self.resps_emb(
 					input[:, :l+1],
 					offset = 0 if l == 0 else 1, # or maybe set to 1
-					quant_level = l
+					quant_level = l,
+					sums = False
 				) for l in range( input.shape[-1] - 1 ) ])
 			else: # get RVQ level 0, or up to targetted RVQ level inference
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 3a143b4..dc7ea97 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -226,13 +226,11 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	args.references = args.references.split(";") if args.references is not None else []
 	if args.max_ar_steps == 0:
 		for i, path in enumerate( args.references ):
-			print(i, path)
 			metadata = torchaudio.info(path)
 			duration = metadata.num_frames / metadata.sample_rate
 			args.max_ar_steps += duration
 		args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second
 
-	tts = init_tts()
 
 	gr.Info("Inferencing...")
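
A minimal sketch of what the new `sums` override enables, using a toy stand-in (`ToyAudioEmbedding` and all of its parameters are hypothetical illustration, not vall_e's actual classes): with `sums=True` the embedding module sums its tables for levels 0..quant_level internally, while the STT path passes `sums=False` so it can perform that sum manually, pointing RVQ level 0 at the AR table (offset 0) and every later level at the NAR tables (offset 1), mirroring the base.py hunk above.

	import torch
	from torch import nn, Tensor

	class ToyAudioEmbedding(nn.Module):
		def __init__(self, n_tokens: int, d_model: int, n_levels: int, sums: bool = True):
			super().__init__()
			self.sums = sums  # module-wide default; a per-call `sums` argument overrides it
			self.embeddings = nn.ModuleList([ nn.Embedding(n_tokens, d_model) for _ in range(n_levels) ])

		def forward(self, xi: Tensor, offset: int = 0, quant_level: int = 0, sums: bool | None = None) -> Tensor:
			if sums is None:
				sums = self.sums  # mirror the patch: fall back to the default only when unset
			if sums:
				# internal summation over levels 0..quant_level (the usual prom/resp path)
				return sum( self.embeddings[k + offset](xi[:, k]) for k in range(quant_level + 1) )
			# sums=False: embed only the requested level; the caller sums across levels itself
			return self.embeddings[quant_level + offset](xi[:, quant_level])

	emb = ToyAudioEmbedding(n_tokens=1024, d_model=16, n_levels=9)
	resps = torch.randint(0, 1024, (75, 8))  # (timesteps, RVQ levels)

	# STT-style manual sum across levels, shaped like the summed_embeddings_task branch
	embedding = sum(
		emb(resps[:, :l + 1], offset=0 if l == 0 else 1, quant_level=l, sums=False)
		for l in range(resps.shape[-1] - 1)
	)

Calling `emb(resps, quant_level=7)` without the override would instead use the module default, which is what the non-STT paths continue to rely on.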