more coping with the NAR len

This commit is contained in:
mrq 2024-08-03 20:23:36 -05:00
parent 11fa3da665
commit d0a5c7eca2

View File

@ -91,14 +91,14 @@ class MultiEmbedding(nn.Module):
self.n_tokens = n_tokens
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resps_emb)
# I imagine this is an oversight in the NAR.
def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
if len(x_list) == 0:
return []
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
# the NAR cannot share RVQ-bin level 0 with the AR for the resps_emb
if self.monolithic:
w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
else:
@ -175,8 +175,9 @@ class AudioEmbedding(nn.Module):
for i, embedding in enumerate(self.embeddings):
embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
def external_embeddings(self, input: Tensor) -> Tensor:
quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
def external_embeddings(self, input: Tensor, quant_level: int | None = None ) -> Tensor:
if quant_level is None:
quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
# for AR, trim any stop tokens
has_stop_token = False
@ -212,8 +213,9 @@ class AudioEmbedding(nn.Module):
return embedding
def internal_forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
def internal_forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor:
if quant_level is None:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
if self.sums and quant_level > 0:
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
@ -223,11 +225,11 @@ class AudioEmbedding(nn.Module):
return x
def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
x = self.internal_forward( xi, offset ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
def forward(self, xi: Tensor, offset: int = 0, quant_level: int | None = None ) -> Tensor:
x = self.internal_forward( xi, offset = offset, quant_level = quant_level ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
if self.external_mode and xi.shape[0] > 0:
external_embeddings = self.external_embeddings( xi )
external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
if self.external_mode == "exclusive":
return external_embeddings
x += external_embeddings
@ -952,9 +954,15 @@ class Base(nn.Module):
# get RVQ level 0, or up to targetted RVQ level inference
if self.version <= 4:
return self.proms_emb( input if quant_level == 0 else input[:, :quant_level] )
return self.proms_emb(
input if quant_level == 0 else input[:, :quant_level]
)
return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], offset = 0 )
return self.proms_emb(
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
offset = 0,
)
# yuck
token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
@ -972,6 +980,7 @@ class Base(nn.Module):
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
task_type = "tts"
input_prom = None
for name, input in batch_input:
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
embedding = None
@ -992,20 +1001,32 @@ class Base(nn.Module):
embedding = self.langs_emb( input )
elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input
input_prom = torch.cat([ prom for prom in proms if isinstance(input, torch.Tensor) ])
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
if "len" in self.capabilities and quant_level == 0:
"""
# fill with "stop" tokens for NAR-only model
embedding = self.resps_emb(
torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token),
offset = 0
)
"""
# fill with filler tokens for NAR-only model
embedding = self.dropout_token.repeat((input.shape[0], 1))
if input_prom is not None:
# fill with the prom as the initial condition
repeat = (input.shape[0] // input_prom.shape[0]) + 1
repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]
embedding = self.resps_emb(
repeated,
offset = 0,
quant_level = 0,
)
else:
# fill with "stop" token from the len layer for the NAR-only model
embedding = self.resps_emb(
# self.dropout_token.repeat((input.shape[0], 1)),
torch.full_like(input if input.dim() == 1 else input[..., 0], 12),
offset = 0,
quant_level = 0,
)
else:
# get RVQ level 0, or up to targetted RVQ level inference
if self.version <= 4:
@ -1016,7 +1037,8 @@ class Base(nn.Module):
else:
embedding = self.resps_emb(
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
offset = 0 if quant_level == 0 or "len" in self.capabilities else 1
offset = 1 if "len" in self.capabilities else (0 if quant_level == 0 else 1),
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
)
# apply token dropout