more coping with the NAR len
This commit is contained in:
parent
11fa3da665
commit
d0a5c7eca2
|
@ -91,14 +91,14 @@ class MultiEmbedding(nn.Module):
|
|||
self.n_tokens = n_tokens
|
||||
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
|
||||
|
||||
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
|
||||
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resps_emb)
|
||||
# I imagine this is an oversight in the NAR.
|
||||
def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
|
||||
if len(x_list) == 0:
|
||||
return []
|
||||
|
||||
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
|
||||
# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
|
||||
# the NAR cannot share RVQ-bin level 0 with the AR for the resps_emb
|
||||
if self.monolithic:
|
||||
w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
|
||||
else:
|
||||
|
@ -175,7 +175,8 @@ class AudioEmbedding(nn.Module):
|
|||
for i, embedding in enumerate(self.embeddings):
|
||||
embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
|
||||
|
||||
def external_embeddings(self, input: Tensor) -> Tensor:
|
||||
def external_embeddings(self, input: Tensor, quant_level: int | None = None ) -> Tensor:
|
||||
if quant_level is None:
|
||||
quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
|
||||
|
||||
# for AR, trim any stop tokens
|
||||
|
@ -212,7 +213,8 @@ class AudioEmbedding(nn.Module):
|
|||
|
||||
return embedding
|
||||
|
||||
def internal_forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
|
||||
def internal_forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor:
|
||||
if quant_level is None:
|
||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||
|
||||
if self.sums and quant_level > 0:
|
||||
|
@ -223,11 +225,11 @@ class AudioEmbedding(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
|
||||
x = self.internal_forward( xi, offset ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
|
||||
def forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor:
|
||||
x = self.internal_forward( xi, offset = offset, quant_level = quant_level ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
|
||||
|
||||
if self.external_mode and xi.shape[0] > 0:
|
||||
external_embeddings = self.external_embeddings( xi )
|
||||
external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
|
||||
if self.external_mode == "exclusive":
|
||||
return external_embeddings
|
||||
x += external_embeddings
|
||||
|
@ -952,9 +954,15 @@ class Base(nn.Module):
|
|||
|
||||
# get RVQ level 0, or up to targetted RVQ level inference
|
||||
if self.version <= 4:
|
||||
return self.proms_emb( input if quant_level == 0 else input[:, :quant_level] )
|
||||
return self.proms_emb(
|
||||
input if quant_level == 0 else input[:, :quant_level]
|
||||
)
|
||||
|
||||
return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], offset = 0 )
|
||||
return self.proms_emb(
|
||||
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
|
||||
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
|
||||
offset = 0,
|
||||
)
|
||||
|
||||
# yuck
|
||||
token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
|
||||
|
@ -972,6 +980,7 @@ class Base(nn.Module):
|
|||
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
|
||||
|
||||
task_type = "tts"
|
||||
input_prom = None
|
||||
for name, input in batch_input:
|
||||
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
|
||||
embedding = None
|
||||
|
@ -992,20 +1001,32 @@ class Base(nn.Module):
|
|||
embedding = self.langs_emb( input )
|
||||
elif name == "prom":
|
||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||
input_prom = torch.cat([ prom for prom in proms if isinstance(input, torch.Tensor) ])
|
||||
|
||||
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
|
||||
elif name == "tone" and self.tones_emb is not None:
|
||||
embedding = self.tones_emb( input )
|
||||
elif name == "resp":
|
||||
if "len" in self.capabilities and quant_level == 0:
|
||||
"""
|
||||
# fill with "stop" tokens for NAR-only model
|
||||
if input_prom is not None:
|
||||
# fill with the prom as the initial condition
|
||||
repeat = (input.shape[0] // input_prom.shape[0]) + 1
|
||||
repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]
|
||||
|
||||
embedding = self.resps_emb(
|
||||
torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token),
|
||||
offset = 0
|
||||
repeated,
|
||||
offset = 0,
|
||||
quant_level = 0,
|
||||
)
|
||||
"""
|
||||
# fill with filler tokens for NAR-only model
|
||||
embedding = self.dropout_token.repeat((input.shape[0], 1))
|
||||
else:
|
||||
# fill with "stop" token from the len layer for the NAR-only model
|
||||
embedding = self.resps_emb(
|
||||
# self.dropout_token.repeat((input.shape[0], 1)),
|
||||
torch.full_like(input if input.dim() == 1 else input[..., 0], 12),
|
||||
offset = 0,
|
||||
quant_level = 0,
|
||||
)
|
||||
|
||||
else:
|
||||
# get RVQ level 0, or up to targetted RVQ level inference
|
||||
if self.version <= 4:
|
||||
|
@ -1016,7 +1037,8 @@ class Base(nn.Module):
|
|||
else:
|
||||
embedding = self.resps_emb(
|
||||
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
||||
offset = 0 if quant_level == 0 or "len" in self.capabilities else 1
|
||||
offset = 1 if "len" in self.capabilities else (0 if quant_level == 0 else 1),
|
||||
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
|
||||
)
|
||||
|
||||
# apply token dropout
|
||||
|
|
Loading…
Reference in New Issue
Block a user