oops
This commit is contained in:
parent
5d66a7db52
commit
6a967f91b9
|
@ -223,7 +223,7 @@ class AudioEmbedding(nn.Module):
|
||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor:
|
def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
|
||||||
if offset is None:
|
if offset is None:
|
||||||
# prom
|
# prom
|
||||||
if self.capabilities is None:
|
if self.capabilities is None:
|
||||||
|
@ -236,6 +236,9 @@ class AudioEmbedding(nn.Module):
|
||||||
elif quant_level > 0:
|
elif quant_level > 0:
|
||||||
offset = 1
|
offset = 1
|
||||||
|
|
||||||
|
if sums is None:
|
||||||
|
sums = self.sums
|
||||||
|
|
||||||
if quant_level is None:
|
if quant_level is None:
|
||||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||||
|
|
||||||
|
@ -247,8 +250,8 @@ class AudioEmbedding(nn.Module):
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None ) -> Tensor:
|
def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
|
||||||
x = self.internal_forward( xi, offset = offset, quant_level = quant_level ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
|
x = self.internal_forward( xi, offset = offset, quant_level = quant_level, sums = sums ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
|
||||||
|
|
||||||
if self.external_mode and xi.shape[0] > 0:
|
if self.external_mode and xi.shape[0] > 0:
|
||||||
external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
|
external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
|
||||||
|
@ -1093,10 +1096,12 @@ class Base(nn.Module):
|
||||||
)
|
)
|
||||||
# cheat-y way to handle performing STT across all levels
|
# cheat-y way to handle performing STT across all levels
|
||||||
elif task_type in summed_embeddings_task:
|
elif task_type in summed_embeddings_task:
|
||||||
|
# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
|
||||||
embedding = sum([ self.resps_emb(
|
embedding = sum([ self.resps_emb(
|
||||||
input[:, :l+1],
|
input[:, :l+1],
|
||||||
offset = 0 if l == 0 else 1, # or maybe set to 1
|
offset = 0 if l == 0 else 1, # or maybe set to 1
|
||||||
quant_level = l
|
quant_level = l,
|
||||||
|
sums = False
|
||||||
) for l in range( input.shape[-1] - 1 ) ])
|
) for l in range( input.shape[-1] - 1 ) ])
|
||||||
else:
|
else:
|
||||||
# get RVQ level 0, or up to targetted RVQ level inference
|
# get RVQ level 0, or up to targetted RVQ level inference
|
||||||
|
|
|
@ -226,13 +226,11 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
args.references = args.references.split(";") if args.references is not None else []
|
args.references = args.references.split(";") if args.references is not None else []
|
||||||
if args.max_ar_steps == 0:
|
if args.max_ar_steps == 0:
|
||||||
for i, path in enumerate( args.references ):
|
for i, path in enumerate( args.references ):
|
||||||
print(i, path)
|
|
||||||
metadata = torchaudio.info(path)
|
metadata = torchaudio.info(path)
|
||||||
duration = metadata.num_frames / metadata.sample_rate
|
duration = metadata.num_frames / metadata.sample_rate
|
||||||
args.max_ar_steps += duration
|
args.max_ar_steps += duration
|
||||||
args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second
|
args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second
|
||||||
|
|
||||||
|
|
||||||
tts = init_tts()
|
tts = init_tts()
|
||||||
|
|
||||||
gr.Info("Inferencing...")
|
gr.Info("Inferencing...")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user