more coping with the NAR len
This commit is contained in:
parent
11fa3da665
commit
d0a5c7eca2
|
@ -91,14 +91,14 @@ class MultiEmbedding(nn.Module):
|
||||||
self.n_tokens = n_tokens
|
self.n_tokens = n_tokens
|
||||||
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
|
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
|
||||||
|
|
||||||
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
|
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resps_emb)
|
||||||
# I imagine this is an oversight in the NAR.
|
# I imagine this is an oversight in the NAR.
|
||||||
def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
|
def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
|
||||||
if len(x_list) == 0:
|
if len(x_list) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# this "strategy" will reserve the weight[0] for the AR and weight[1:] for the NAR
|
# this "strategy" will reserve the weight[0] for the AR and weight[1:] for the NAR
|
||||||
# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
|
# the NAR cannot share RVQ-bin level 0 with the AR for the resps_emb
|
||||||
if self.monolithic:
|
if self.monolithic:
|
||||||
w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
|
w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
|
||||||
else:
|
else:
|
||||||
|
@ -175,8 +175,9 @@ class AudioEmbedding(nn.Module):
|
||||||
for i, embedding in enumerate(self.embeddings):
|
for i, embedding in enumerate(self.embeddings):
|
||||||
embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
|
embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
|
||||||
|
|
||||||
def external_embeddings(self, input: Tensor) -> Tensor:
|
def external_embeddings(self, input: Tensor, quant_level: int | None = None ) -> Tensor:
|
||||||
quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
|
if quant_level is None:
|
||||||
|
quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
|
||||||
|
|
||||||
# for AR, trim any stop tokens
|
# for AR, trim any stop tokens
|
||||||
has_stop_token = False
|
has_stop_token = False
|
||||||
|
@ -212,8 +213,9 @@ class AudioEmbedding(nn.Module):
|
||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def internal_forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
|
def internal_forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor:
|
||||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
if quant_level is None:
|
||||||
|
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||||
|
|
||||||
if self.sums and quant_level > 0:
|
if self.sums and quant_level > 0:
|
||||||
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
|
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
|
||||||
|
@ -223,11 +225,11 @@ class AudioEmbedding(nn.Module):
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
|
def forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor:
|
||||||
x = self.internal_forward( xi, offset ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
|
x = self.internal_forward( xi, offset = offset, quant_level = quant_level ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
|
||||||
|
|
||||||
if self.external_mode and xi.shape[0] > 0:
|
if self.external_mode and xi.shape[0] > 0:
|
||||||
external_embeddings = self.external_embeddings( xi )
|
external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
|
||||||
if self.external_mode == "exclusive":
|
if self.external_mode == "exclusive":
|
||||||
return external_embeddings
|
return external_embeddings
|
||||||
x += external_embeddings
|
x += external_embeddings
|
||||||
|
@ -952,9 +954,15 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# get RVQ level 0, or up to targeted RVQ level inference
|
# get RVQ level 0, or up to targeted RVQ level inference
|
||||||
if self.version <= 4:
|
if self.version <= 4:
|
||||||
return self.proms_emb( input if quant_level == 0 else input[:, :quant_level] )
|
return self.proms_emb(
|
||||||
|
input if quant_level == 0 else input[:, :quant_level]
|
||||||
|
)
|
||||||
|
|
||||||
return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], offset = 0 )
|
return self.proms_emb(
|
||||||
|
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
|
||||||
|
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
|
||||||
|
offset = 0,
|
||||||
|
)
|
||||||
|
|
||||||
# yuck
|
# yuck
|
||||||
token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
|
token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
|
||||||
|
@ -972,6 +980,7 @@ class Base(nn.Module):
|
||||||
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
|
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
|
||||||
|
|
||||||
task_type = "tts"
|
task_type = "tts"
|
||||||
|
input_prom = None
|
||||||
for name, input in batch_input:
|
for name, input in batch_input:
|
||||||
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
|
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
|
||||||
embedding = None
|
embedding = None
|
||||||
|
@ -992,20 +1001,32 @@ class Base(nn.Module):
|
||||||
embedding = self.langs_emb( input )
|
embedding = self.langs_emb( input )
|
||||||
elif name == "prom":
|
elif name == "prom":
|
||||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||||
|
input_prom = torch.cat([ prom for prom in proms if isinstance(input, torch.Tensor) ])
|
||||||
|
|
||||||
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
|
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
|
||||||
elif name == "tone" and self.tones_emb is not None:
|
elif name == "tone" and self.tones_emb is not None:
|
||||||
embedding = self.tones_emb( input )
|
embedding = self.tones_emb( input )
|
||||||
elif name == "resp":
|
elif name == "resp":
|
||||||
if "len" in self.capabilities and quant_level == 0:
|
if "len" in self.capabilities and quant_level == 0:
|
||||||
"""
|
if input_prom is not None:
|
||||||
# fill with "stop" tokens for NAR-only model
|
# fill with the prom as the initial condition
|
||||||
embedding = self.resps_emb(
|
repeat = (input.shape[0] // input_prom.shape[0]) + 1
|
||||||
torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token),
|
repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]
|
||||||
offset = 0
|
|
||||||
)
|
embedding = self.resps_emb(
|
||||||
"""
|
repeated,
|
||||||
# fill with filler tokens for NAR-only model
|
offset = 0,
|
||||||
embedding = self.dropout_token.repeat((input.shape[0], 1))
|
quant_level = 0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# fill with "stop" token from the len layer for the NAR-only model
|
||||||
|
embedding = self.resps_emb(
|
||||||
|
# self.dropout_token.repeat((input.shape[0], 1)),
|
||||||
|
torch.full_like(input if input.dim() == 1 else input[..., 0], 12),
|
||||||
|
offset = 0,
|
||||||
|
quant_level = 0,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# get RVQ level 0, or up to targeted RVQ level inference
|
# get RVQ level 0, or up to targeted RVQ level inference
|
||||||
if self.version <= 4:
|
if self.version <= 4:
|
||||||
|
@ -1016,7 +1037,8 @@ class Base(nn.Module):
|
||||||
else:
|
else:
|
||||||
embedding = self.resps_emb(
|
embedding = self.resps_emb(
|
||||||
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
||||||
offset = 0 if quant_level == 0 or "len" in self.capabilities else 1
|
offset = 1 if "len" in self.capabilities else (0 if quant_level == 0 else 1),
|
||||||
|
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply token dropout
|
# apply token dropout
|
||||||
|
|
Loading…
Reference in New Issue
Block a user