mrq 2024-09-07 22:13:49 -05:00
parent 5d66a7db52
commit 6a967f91b9
2 changed files with 9 additions and 6 deletions

View File

@@ -223,7 +223,7 @@ class AudioEmbedding(nn.Module):
 		return embedding
 
-	def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor:
+	def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
 		if offset is None:
 			# prom
 			if self.capabilities is None:
@@ -236,6 +236,9 @@ class AudioEmbedding(nn.Module):
 			elif quant_level > 0:
 				offset = 1
 
+		if sums is None:
+			sums = self.sums
+
 		if quant_level is None:
 			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
@@ -247,8 +250,8 @@ class AudioEmbedding(nn.Module):
 		return x
 
-	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor:
-		x = self.internal_forward( xi, offset = offset, quant_level = quant_level ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
+	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
+		x = self.internal_forward( xi, offset = offset, quant_level = quant_level, sums = sums ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
 		if self.external_mode and xi.shape[0] > 0:
 			external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
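
To make the new flag concrete, here is a minimal sketch (not the repo's actual code) of how a sums toggle like this typically gates an RVQ embedding: summing the per-level embedding tables up to quant_level versus embedding only a single level. The class name AudioEmbeddingSketch, the table layout, and the constructor parameters are assumptions for illustration.

from torch import Tensor, nn

class AudioEmbeddingSketch(nn.Module):
	def __init__(self, n_tokens: int, d_model: int, n_levels: int, sums: bool = True):
		super().__init__()
		# one embedding table per RVQ level (hypothetical layout)
		self.embeddings = nn.ModuleList([ nn.Embedding(n_tokens, d_model) for _ in range(n_levels) ])
		self.sums = sums # module-level default, used when a caller passes sums=None

	def internal_forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None, sums = None) -> Tensor:
		if sums is None:
			sums = self.sums # the same fallback the diff adds
		if quant_level is None:
			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
		if sums and xi.dim() > 1:
			# sum the embeddings of every RVQ level up to and including quant_level
			return sum( self.embeddings[k + offset](xi[:, k]) for k in range(quant_level + 1) )
		# otherwise embed only the requested level
		token = xi if xi.dim() == 1 else xi[:, quant_level]
		return self.embeddings[quant_level + offset](token)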
@@ -1093,10 +1096,12 @@ class Base(nn.Module):
 				)
 			# cheat-y way to handle performing STT across all levels
 			elif task_type in summed_embeddings_task:
+				# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
 				embedding = sum([ self.resps_emb(
 					input[:, :l+1],
 					offset = 0 if l == 0 else 1, # or maybe set to 1
-					quant_level = l
+					quant_level = l,
+					sums = False
 				) for l in range( input.shape[-1] - 1 ) ])
 			else:
 				# get RVQ level 0, or up to targetted RVQ level inference
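
The hunk above passes sums = False so that each resps_emb call embeds exactly one RVQ level, and the outer sum() combines them manually (the AR table for level 0, the NAR tables above it, per the offset). A hedged usage sketch against the AudioEmbeddingSketch class from earlier, with made-up sizes:

import torch

emb = AudioEmbeddingSketch(n_tokens=1024, d_model=512, n_levels=8) # sizes are illustrative
codes = torch.randint(0, 1024, (75, 8)) # [seq_len, n_rvq_levels] response codes
embedding = sum(
	emb.internal_forward(
		codes[:, :l+1],
		offset = 0 if l == 0 else 1, # level 0 hits the AR table, later levels the NAR tables
		quant_level = l,
		sums = False, # embed only level l; the outer sum() does the combining
	) for l in range( codes.shape[-1] - 1 )
)
print( embedding.shape ) # torch.Size([75, 512])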

View File

@@ -226,13 +226,11 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	args.references = args.references.split(";") if args.references is not None else []
 
 	if args.max_ar_steps == 0:
 		for i, path in enumerate( args.references ):
-			print(i, path)
 			metadata = torchaudio.info(path)
 			duration = metadata.num_frames / metadata.sample_rate
 			args.max_ar_steps += duration
-
 		args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second
 
 	tts = init_tts()
 	gr.Info("Inferencing...")
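
The second file's change just drops a debug print; the surviving heuristic (token budget = total reference duration times roughly 20 tokens per second, per the hunk's own comment) could be isolated as a helper. A sketch, where estimate_max_ar_steps is a hypothetical name and torchaudio.info is the same call the hunk uses:

import math
import torchaudio

def estimate_max_ar_steps(references: list[str], tokens_per_second: int = 20) -> int:
	"""Sum reference-clip durations and convert seconds into an AR token budget."""
	total_seconds = 0.0
	for path in references:
		metadata = torchaudio.info(path) # AudioMetaData exposing num_frames and sample_rate
		total_seconds += metadata.num_frames / metadata.sample_rate
	return math.floor( total_seconds * tokens_per_second )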