set default max_levels for the NAR to 0, which implicitly resolves to the model's max resp levels, because the previous default implicitly assumed all models were outputting 1+7 RVQ bins.
parent 671dca88ee
commit a1f250ffac
@@ -111,6 +111,8 @@ class AR_NAR(Base):
 		)
 		# is NAR
 		prev_list = resps_list
+		if max_levels == 0:
+			max_levels = self.n_resp_levels
 
 		while True:
 			level = prev_list[0].shape[-1]
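The hunk above turns max_levels into a sentinel: 0 now means "use however many resp levels this model was configured with" rather than baking in the old 1+7 RVQ-bin assumption. A minimal sketch of the pattern (the Model class here is illustrative; only the attribute name n_resp_levels comes from the diff):

# minimal sketch, assuming a model configured with some number of RVQ resp levels
class Model:
	def __init__(self, n_resp_levels: int):
		self.n_resp_levels = n_resp_levels

	def resolve_max_levels(self, max_levels: int = 0) -> int:
		# 0 is the sentinel: fall back to this model's own level count
		if max_levels == 0:
			max_levels = self.n_resp_levels
		return max_levels

assert Model(n_resp_levels=3).resolve_max_levels() == 3   # follows the model
assert Model(n_resp_levels=3).resolve_max_levels(2) == 2  # an explicit cap still wins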
@@ -151,21 +151,43 @@ class MultiEmbedding(nn.Embedding):
 		else:
 			w = self.weight
 
 			padded_x_list = []
 
 			for i, xi in enumerate(x_list):
 				xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
 				wi = w.shape[0] - xi.shape[1]
 				xi = F.pad(xi, (0, 0, 0, wi)) # t l k
 				padded_x_list.append(xi.to(w))
 
 			x = torch.cat(padded_x_list) # n l k
 			x = einsum("l k d, n l k -> n d", w, x)
 
 			x_list = x.split([*map(len, x_list)])
 
 		return x_list
 
+	"""
+	w_ar, w_nar = self.weight[:1], self.weight[1:]
+	p_ar_list, p_nar_list = [], []
+
+	for i, xi in enumerate(x_list):
+		if quant_levels is None or quant_levels[i] == 0:
+			w, padded_x_list = w_ar, p_ar_list
+		else:
+			w, padded_x_list = w_nar, p_nar_list
+
+		# pad resp/prom tensor to fit weight
+		xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
+		xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
+		padded_x_list.append(xi.to(w))
+
+	# batch list => batch tensor
+	x_ar_list = einsum("l k d, n l k -> n d", w_ar, torch.cat(p_ar_list)) if len(p_ar_list) > 0 else []
+	x_nar_list = einsum("l k d, n l k -> n d", w_nar, torch.cat(p_nar_list)) if len(p_nar_list) > 0 else []
+
+	x_list = x.split([*map(len, x_list)])
+	"""
+
 """
 # Embedding that sums each RVQ-bin level within a given input acoustic prompt
 class PromEmbedding(nn.Module):
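The unchanged path in this hunk is worth unpacking: token ids are one-hot encoded per level, the level axis is zero-padded up to the embedding weight's full level count, and a single einsum performs the lookup and sums across levels in one shot. A self-contained sketch of that technique, with illustrative shapes (t = timesteps, l = levels, k = tokens, d = model dim, matching the diff's comments):

import torch
import torch.nn.functional as F
from torch import einsum

n_levels, n_tokens, d_model = 8, 1024, 512
w = torch.randn(n_levels, n_tokens, d_model)              # stand-in for self.weight

xi = torch.randint(0, n_tokens, (100, 2))                 # t l': tokens for 2 of 8 levels
oh = F.one_hot(xi.to(torch.int64), num_classes=n_tokens)  # t l' k
oh = F.pad(oh, (0, 0, 0, n_levels - oh.shape[1]))         # t l k: zero rows for missing levels
x = einsum("l k d, t l k -> t d", w, oh.to(w))            # lookup + sum over levels

# the einsum is equivalent to summing a per-level embedding lookup:
ref = sum(w[l, xi[:, l]] for l in range(xi.shape[1]))
assert torch.allclose(x, ref, atol=1e-5)

The zero padding is what makes the sum safe: levels a sequence does not carry contribute an all-zero one-hot row, and therefore nothing to the result. The commented-out block records an abandoned variant that split the weight into AR (level 0) and NAR (levels 1+) halves; note that it still references an undefined x in its final split, which is presumably why it was left disabled.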
@@ -56,7 +56,7 @@ class NAR(Base):
 		text_list: list[Tensor],
 		proms_list: list[Tensor],
 		resps_list: list[Tensor],
-		max_levels: int = 7,
+		max_levels: int = 0,
 		sampling_temperature: float = 0.2,
 		sampling_top_k: int = -100,
 		sampling_top_p: float = 1.0,
@@ -103,6 +103,8 @@ class NAR(Base):
 			prev_list = []
 		else:
 			prev_list = resps_list
+		if max_levels == 0:
+			max_levels = self.n_resp_levels
 
 		while True:
 			level = prev_list[0].shape[-1] - 1
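Taken together with the signature change above, the NAR's sampling loop now derives its stopping point from the model itself. A rough, runnable sketch of the implied control flow (decode_one_level is a hypothetical stand-in for the model's per-level forward pass; only the level arithmetic mirrors the diff):

import torch

def decode_one_level(prev_list):
	# hypothetical stand-in: append one more RVQ level to each response tensor
	return [torch.cat([r, torch.zeros_like(r[..., :1])], dim=-1) for r in prev_list]

def generate(resps_list, n_resp_levels: int, max_levels: int = 0):
	if max_levels == 0:
		max_levels = n_resp_levels       # new default: follow the model's config
	prev_list = resps_list
	while True:
		level = prev_list[0].shape[-1] - 1  # highest level present, 0-indexed
		if level >= max_levels:
			break
		prev_list = decode_one_level(prev_list)
	return prev_list

out = generate([torch.zeros(100, 1)], n_resp_levels=3)
assert out[0].shape[-1] == 4  # the level-0 input plus 3 NAR levels

Under the old hard-coded default of 7, a model configured with a different number of resp levels would either sample levels it was never trained on or stop short; with 0, the loop stops at n_resp_levels unless the caller pins an explicit cap.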