set default max_levels for NAR to 0 and implicitly set it to max resps levels because the previous way was implicitly assuming all models were outputting at 1+7 RVQ bins.

This commit is contained in:
mrq 2023-09-10 20:33:33 -05:00
parent 671dca88ee
commit a1f250ffac
3 changed files with 36 additions and 10 deletions

View File

@ -111,6 +111,8 @@ class AR_NAR(Base):
)
# is NAR
prev_list = resps_list
if max_levels == 0:
max_levels = self.n_resp_levels
while True:
level = prev_list[0].shape[-1]

View File

@ -151,21 +151,43 @@ class MultiEmbedding(nn.Embedding):
else:
w = self.weight
padded_x_list = []
padded_x_list = []
for i, xi in enumerate(x_list):
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
wi = w.shape[0] - xi.shape[1]
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
padded_x_list.append(xi.to(w))
for i, xi in enumerate(x_list):
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
wi = w.shape[0] - xi.shape[1]
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
padded_x_list.append(xi.to(w))
x = torch.cat(padded_x_list) # n l k
x = einsum("l k d, n l k -> n d", w, x)
x = torch.cat(padded_x_list) # n l k
x = einsum("l k d, n l k -> n d", w, x)
x_list = x.split([*map(len, x_list)])
x_list = x.split([*map(len, x_list)])
return x_list
"""
w_ar, w_nar = self.weight[:1], self.weight[1:]
p_ar_list, p_nar_list = [], []
for i, xi in enumerate(x_list):
if quant_levels is None or quant_levels[i] == 0:
w padded_x_list, = w_ar, p_ar_list
else:
w, padded_x_list = w_nar, p_nar_list
# pad resp/prom tensor to fit weight
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
padded_x_list.append(xi.to(w))
# batch list => batch tensor
x_ar_list = einsum("l k d, n l k -> n d", w_ar, torch.cat(p_ar_list)) if len(p_ar_list) > 0 else []
x_nar_list = einsum("l k d, n l k -> n d", w_nar, torch.cat(p_nar_list)) if len(p_nar_list) > 0 else []
x_list = x.split([*map(len, x_list)])
"""
"""
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class PromEmbedding(nn.Module):

View File

@ -56,7 +56,7 @@ class NAR(Base):
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
max_levels: int = 7,
max_levels: int = 0,
sampling_temperature: float = 0.2,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
@ -103,6 +103,8 @@ class NAR(Base):
prev_list = []
else:
prev_list = resps_list
if max_levels == 0:
max_levels = self.n_resp_levels
while True:
level = prev_list[0].shape[-1] - 1