oops
This commit is contained in:
parent
5d66a7db52
commit
6a967f91b9
|
@ -223,7 +223,7 @@ class AudioEmbedding(nn.Module):
|
|||
|
||||
return embedding
|
||||
|
||||
def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor:
|
||||
def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
|
||||
if offset is None:
|
||||
# prom
|
||||
if self.capabilities is None:
|
||||
|
@ -236,6 +236,9 @@ class AudioEmbedding(nn.Module):
|
|||
elif quant_level > 0:
|
||||
offset = 1
|
||||
|
||||
if sums is None:
|
||||
sums = self.sums
|
||||
|
||||
if quant_level is None:
|
||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||
|
||||
|
@ -247,8 +250,8 @@ class AudioEmbedding(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor:
|
||||
x = self.internal_forward( xi, offset = offset, quant_level = quant_level ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
|
||||
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
|
||||
x = self.internal_forward( xi, offset = offset, quant_level = quant_level, sums = sums ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
|
||||
|
||||
if self.external_mode and xi.shape[0] > 0:
|
||||
external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
|
||||
|
@ -1093,10 +1096,12 @@ class Base(nn.Module):
|
|||
)
|
||||
# cheat-y way to handle performing STT across all levels
|
||||
elif task_type in summed_embeddings_task:
|
||||
# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
|
||||
embedding = sum([ self.resps_emb(
|
||||
input[:, :l+1],
|
||||
offset = 0 if l == 0 else 1, # or maybe set to 1
|
||||
quant_level = l
|
||||
quant_level = l,
|
||||
sums = False
|
||||
) for l in range( input.shape[-1] - 1 ) ])
|
||||
else:
|
||||
# get RVQ level 0, or up to targetted RVQ level inference
|
||||
|
|
|
@ -226,13 +226,11 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
args.references = args.references.split(";") if args.references is not None else []
|
||||
if args.max_ar_steps == 0:
|
||||
for i, path in enumerate( args.references ):
|
||||
print(i, path)
|
||||
metadata = torchaudio.info(path)
|
||||
duration = metadata.num_frames / metadata.sample_rate
|
||||
args.max_ar_steps += duration
|
||||
args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second
|
||||
|
||||
|
||||
tts = init_tts()
|
||||
|
||||
gr.Info("Inferencing...")
|
||||
|
|
Loading…
Reference in New Issue
Block a user