diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 9a00646..415df55 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -111,6 +111,8 @@ class AR_NAR(Base): ) # is NAR prev_list = resps_list + if max_levels == 0: + max_levels = self.n_resp_levels while True: level = prev_list[0].shape[-1] diff --git a/vall_e/models/base.py b/vall_e/models/base.py index fc8e646..01025f1 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -151,21 +151,43 @@ class MultiEmbedding(nn.Embedding): else: w = self.weight - padded_x_list = [] + padded_x_list = [] - for i, xi in enumerate(x_list): - xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k - wi = w.shape[0] - xi.shape[1] - xi = F.pad(xi, (0, 0, 0, wi)) # t l k - padded_x_list.append(xi.to(w)) + for i, xi in enumerate(x_list): + xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k + wi = w.shape[0] - xi.shape[1] + xi = F.pad(xi, (0, 0, 0, wi)) # t l k + padded_x_list.append(xi.to(w)) - x = torch.cat(padded_x_list) # n l k - x = einsum("l k d, n l k -> n d", w, x) + x = torch.cat(padded_x_list) # n l k + x = einsum("l k d, n l k -> n d", w, x) - x_list = x.split([*map(len, x_list)]) + x_list = x.split([*map(len, x_list)]) return x_list + """ + w_ar, w_nar = self.weight[:1], self.weight[1:] + p_ar_list, p_nar_list = [], [] + + for i, xi in enumerate(x_list): + if quant_levels is None or quant_levels[i] == 0: + w, padded_x_list = w_ar, p_ar_list + else: + w, padded_x_list = w_nar, p_nar_list + + # pad resp/prom tensor to fit weight + xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k + xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k + padded_x_list.append(xi.to(w)) + + # batch list => batch tensor + x_ar_list = einsum("l k d, n l k -> n d", w_ar, torch.cat(p_ar_list)) if len(p_ar_list) > 0 else [] + x_nar_list = einsum("l k d, n l k -> n d", w_nar, torch.cat(p_nar_list)) if len(p_nar_list) > 0 else [] + + x_list = 
x.split([*map(len, x_list)]) + """ + """ # Embedding that sums each RVQ-bin level within a given input acoustic prompt class PromEmbedding(nn.Module): diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 06ec095..88aeede 100755 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -56,7 +56,7 @@ class NAR(Base): text_list: list[Tensor], proms_list: list[Tensor], resps_list: list[Tensor], - max_levels: int = 7, + max_levels: int = 0, sampling_temperature: float = 0.2, sampling_top_k: int = -100, sampling_top_p: float = 1.0, @@ -103,6 +103,8 @@ class NAR(Base): prev_list = [] else: prev_list = resps_list + if max_levels == 0: + max_levels = self.n_resp_levels while True: level = prev_list[0].shape[-1] - 1