prom_list -> proms_list
parent de59c04c50
commit ea5e438fdb

3 binary files changed (not shown)
@@ -0,0 +1,2 @@
+from .ar import AR
+from .nar import NAR
@@ -28,12 +28,12 @@ class AR(Base):
     def forward(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         resp_list: list[Tensor],
     ):
         return super().forward(
             text_list,
-            prom_list,
+            proms_list,
             resp_list,
             resp_list,
             quant_level=0,
@@ -44,7 +44,7 @@ class AR(Base):
     def generate(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         max_steps: int = 1000,
     ):
         device = text_list[0].device
@@ -53,7 +53,7 @@ class AR(Base):
         ]
         stopped = [False] * len(text_list)
         for _ in trange(max_steps):
-            r = super().forward(text_list, prom_list, resp_list)
+            r = super().forward(text_list, proms_list, resp_list)
             for i, ri in enumerate(r):
                 if ri.item() == self.stop_token:
                     stopped[i] = True
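Note on the loop this hunk touches: generation decodes one token per step for the whole batch, and a per-sample stopped flag latches once that sample emits the stop token. A minimal sketch of the control flow, with a random stand-in for the model call (batch, stop_token, and the vocabulary size are illustrative values, not from this repo):

import torch

batch, stop_token, max_steps = 2, 0, 10  # illustrative values
stopped = [False] * batch

for _ in range(max_steps):
    r = torch.randint(0, 4, (batch,))  # stand-in for super().forward(...)
    for i, ri in enumerate(r):
        if ri.item() == stop_token:
            stopped[i] = True  # latch: sample i is done
    if all(stopped):
        break  # every sample in the batch has stopped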
@@ -65,7 +65,10 @@ class AR(Base):


 def example_usage():
+    from functools import partial
+
     import soundfile
+    from einops import repeat

     device = "cuda"

@@ -79,9 +82,10 @@ def example_usage():
         torch.tensor([2, 3], device=device),
     ]

-    prom_list = [
-        torch.tensor([1, 2, 3], device=device),
-        torch.tensor([2, 3], device=device),
+    x8 = partial(repeat, pattern="t -> t q", q=8)
+    proms_list = [
+        x8(torch.tensor([1, 2, 3], device=device)),
+        x8(torch.tensor([2, 3], device=device)),
     ]

     resp_list = [
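Note: the new x8 helper tiles a 1-D token sequence across all 8 quantizer levels, which is what gives proms_list its [t' k] shape. A quick sketch of what it produces (assuming einops is installed):

from functools import partial

import torch
from einops import repeat

x8 = partial(repeat, pattern="t -> t q", q=8)
t = torch.tensor([1, 2, 3])
print(x8(t).shape)  # torch.Size([3, 8]): each token copied across q=8 levels
print(x8(t)[:, 0])  # tensor([1, 2, 3]): level 0 is the original sequence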
@@ -91,7 +95,7 @@ def example_usage():

     out = model.generate(
         text_list,
-        prom_list,
+        proms_list,
         max_steps=200,
     )

@@ -101,7 +105,7 @@ def example_usage():

     for i in range(100):
         optimizer.zero_grad()
-        _ = model(text_list, prom_list, resp_list)
+        _ = model(text_list, proms_list, resp_list)

         losses = model.loss
         sum(losses.values()).backward()
@@ -110,7 +114,7 @@ def example_usage():
         if i % 20 == 0:
            print(f"iter={i}, {losses}.")

-    out = model.generate(text_list, prom_list, max_steps=200)
+    out = model.generate(text_list, proms_list, max_steps=200)

     print(qnt)
     print(out)

@@ -178,10 +178,28 @@ class Block(nn.Sequential):


 class Embedding(nn.Embedding):
-    def forward(self, x: list[Tensor]) -> list[Tensor]:
-        if len(x) == 0:
+    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
+        if len(x_list) == 0:
             return []
-        return super().forward(torch.cat(x)).split([*map(len, x)])
+        return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
+
+
+class MultiEmbedding(nn.Module):
+    def __init__(self, num_embeddings, embedding_dim, n_levels):
+        super().__init__()
+        self.n_levels = n_levels
+        self.num_embeddings = num_embeddings
+        self.emb = nn.Embedding(n_levels * num_embeddings, embedding_dim)
+
+    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
+        if len(x_list) == 0:
+            return []
+        x = torch.cat(x_list)
+        assert x.shape[1] == self.n_levels
+        w = rearrange(self.emb.weight, "(q k) d -> q k d", q=self.n_levels)
+        x = F.one_hot(x, num_classes=self.num_embeddings).float()  # n q -> n q k
+        x = einsum("q k d, n q k -> n d", w, x)
+        return x.split([*map(len, x_list)])


 def _join(x: tuple[Tensor], sep: Tensor):
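Note on the new MultiEmbedding: the one-hot/einsum pair sums one embedding lookup per quantization level while keeping all levels in a single fused nn.Embedding table. A sketch of that equivalence, under the assumption (matching the rearrange above) that level q's rows sit at offset q * num_embeddings:

import torch
import torch.nn.functional as F
from einops import rearrange

n_levels, n_tokens, d_model = 8, 10, 4
emb = torch.nn.Embedding(n_levels * n_tokens, d_model)
x = torch.randint(0, n_tokens, (5, n_levels))  # n q

# Path taken in the diff: one-hot, then contract over levels and tokens.
w = rearrange(emb.weight, "(q k) d -> q k d", q=n_levels)
y1 = torch.einsum("qkd,nqk->nd", w, F.one_hot(x, n_tokens).float())

# Equivalent: offset each level's codes into its slice of the table, then sum.
offsets = torch.arange(n_levels) * n_tokens
y2 = emb(x + offsets).sum(dim=1)

assert torch.allclose(y1, y2, atol=1e-5)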
@@ -216,6 +234,8 @@ class Base(nn.Module):
         n_heads: int = 8,
         n_layers: int = 12,
         p_dropout: float = 0.1,
+        n_prom_levels: int = 8,
+        resp_loss_only: bool = False,
     ):
         super().__init__()
         self.n_tokens = n_tokens
@@ -227,7 +247,10 @@ class Base(nn.Module):
         n_resp_tokens = n_tokens + n_stop_tokens

         self.text_emb = Embedding(n_tokens, d_model)
-        self.prom_emb = Embedding(n_tokens, d_model)
+
+        # It's not clear whether the whole prom is used or only its first quantization level.
+        # Use all levels for now: they carry more information and need no sampling. Or do we?
+        self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=n_prom_levels)

         # +1 to include the stop token
         self.resp_embs = nn.ModuleList(
@@ -243,6 +266,8 @@ class Base(nn.Module):

         self.classifier = nn.Linear(d_model, n_resp_tokens)

+        self.resp_loss_only = resp_loss_only
+
     @property
     def stop_token(self):
         if not self.use_stop_token:
@@ -265,7 +290,7 @@ class Base(nn.Module):
     def forward(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         resp_list: list[Tensor],
         targ_list: list[Tensor] | None = None,
         quant_level: int = 0,
@@ -278,7 +303,7 @@ class Base(nn.Module):
     def forward(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         resp_list: list[Tensor],
         targ_list: list[Tensor] | None = None,
         quant_level: int = 0,
@@ -290,7 +315,7 @@ class Base(nn.Module):
     def forward(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         resp_list: list[Tensor],
         targ_list: list[Tensor] | None = None,
         quant_level: int = 0,
@@ -300,7 +325,7 @@ class Base(nn.Module):
         """
         Args:
             text_list: [t] * b
-            prom_list: [t'] * b
+            proms_list: [t' k] * b
             resp_list: [t''] * b, one quantization level only
             targ_list: [t''] * b, one quantization level only; when given, loss will be computed
             quant_level: which quantization level to feed forward; used in NAR mode.
@@ -311,7 +336,7 @@ class Base(nn.Module):
         """
         x_list = self._samplewise_merge_tensors(
             self.text_emb(text_list),
-            self.prom_emb(prom_list),
+            self.prom_emb(proms_list),
             self.resp_embs[quant_level](resp_list),
             sep=self.sep,
         )
@@ -334,14 +359,21 @@ class Base(nn.Module):
         device = h.device

         ignore_sep = torch.tensor(self.ignore_index, device=device)

+        # Predict the first-level prom.
+        prom_list = [t[..., 0] for t in proms_list]
         text_prom_list = self._samplewise_merge_tensors(
             text_list, prom_list, sep=ignore_sep
         )

         # Shift every token one step earlier, since the future is unknown.
+        # If we don't want to compute this loss at all, set everything to ignored.
         for i in range(len(text_prom_list)):
-            text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
-            text_prom_list[i][-1] = self.ignore_index
+            if self.resp_loss_only:
+                text_prom_list[i][:] = self.ignore_index
+            else:
+                text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
+                text_prom_list[i][-1] = self.ignore_index

         if shift_targ_list:
             # Also shift the target earlier in autoregressive mode.

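Note on the roll(-1) above: it builds next-token targets, so position i is scored against token i+1, and the wrapped-around last position is masked with ignore_index so cross-entropy skips it. A standalone sketch of the same trick (IGNORE_INDEX = -100 is an assumption here; it happens to be F.cross_entropy's default, and this repo's self.ignore_index may differ):

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed value, not taken from this repo

tokens = torch.tensor([5, 7, 2, 9])
targ = tokens.roll(-1, dims=0)  # position i now holds token i+1
targ[-1] = IGNORE_INDEX         # last position has no future token to predict
logits = torch.randn(4, 16)     # [t, vocab], stand-in for model output
loss = F.cross_entropy(logits, targ, ignore_index=IGNORE_INDEX)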
@@ -21,7 +21,7 @@ class NAR(Base):
     def forward(
         self,
         text_list: list[Tensor],
-        prom_list: list[Tensor],
+        proms_list: list[Tensor],
         *,
         resp_list: list[Tensor] | None = None,
         resps_list: list[Tensor] | None = None,
@@ -29,7 +29,7 @@ class NAR(Base):
         """
         Args:
             text_list: [t] * b
-            prom_list: [t'] * b
+            proms_list: [t' k] * b
             resp_list: [t'] * b, quants at level 0.
             resps_list: [t''] * b, 8 quantization levels for training.
         Returns:
@@ -52,7 +52,7 @@ class NAR(Base):
         for i in range(self.n_levels):
             hyp_resp_list = super().forward(
                 text_list,
-                prom_list,
+                proms_list,
                 hyp_resp_lists[-1],
                 return_all_resp=True,
                 shift_targ_list=False,
@@ -70,7 +70,7 @@ class NAR(Base):
             next_resp_list = [o[..., i + 1] for o in resps_list]
             hyp_resp_list = super().forward(
                 text_list,
-                prom_list,
+                proms_list,
                 resp_list,
                 next_resp_list,
                 return_all_resp=True,
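Note on o[..., i + 1]: each response tensor is [t, k], so this picks the codes of quantization level i + 1 as the training target for the levels below it. A shape-only sketch of that indexing (tensor contents are random placeholders):

import torch

k = 8  # quantization levels
resps_list = [torch.randint(0, 1024, (100, k))]  # [t, k] per sample

resp_list = [o[..., 0] for o in resps_list]  # level-0 codes, shape [t]
for i in range(k - 1):
    next_resp_list = [o[..., i + 1] for o in resps_list]  # targets at level i+1
    # here the NAR would call super().forward(text_list, proms_list, resp_list, next_resp_list, ...)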
@@ -90,7 +90,10 @@ class NAR(Base):


 def example_usage():
+    from functools import partial
+
     import soundfile
+    from einops import repeat

     from ..emb.qnt import decode
     from ..utils import gather_attribute
@@ -107,9 +110,10 @@ def example_usage():
         torch.tensor([2, 3], device=device),
     ]

-    prom_list = [
-        torch.tensor([1, 2, 3], device=device),
-        torch.tensor([2, 3], device=device),
+    x8 = partial(repeat, pattern="t -> t q", q=8)
+    proms_list = [
+        x8(torch.tensor([1, 2, 3], device=device)),
+        x8(torch.tensor([2, 3], device=device)),
     ]

     resp_list = [
@@ -118,13 +122,11 @@ def example_usage():
     ]

     resps_list = [
-        torch.tensor([1, 2, 3], device=device)
-        .unsqueeze(-1)
-        .repeat_interleave(8, dim=-1),
+        x8(torch.tensor([1, 2, 3], device=device)),
         resps.t().to(device),
     ]

-    out = model(text_list, prom_list, resp_list=resp_list)
+    out = model(text_list, proms_list, resp_list=resp_list)
     codes = rearrange(out[1], "t k -> 1 k t")
     print(codes)
     wavs, sr = decode(codes)
@@ -134,7 +136,7 @@ def example_usage():

     for i in range(100):
         optimizer.zero_grad()
-        _ = model(text_list, prom_list, resps_list=resps_list)
+        _ = model(text_list, proms_list, resps_list=resps_list)

         losses = gather_attribute(model, "loss")
         loss = sum(losses.values())
@@ -146,7 +148,7 @@ def example_usage():
         stats["loss"] = loss.item()
         print(f"iter={i}, {stats}.")

-    out = model(text_list, prom_list, resp_list=resp_list)
+    out = model(text_list, proms_list, resp_list=resp_list)
     codes = rearrange(out[1], "t k -> 1 k t")
     wavs, sr = decode(codes)
     soundfile.write("data/test/test.nar.recon.wav", wavs.cpu()[0, 0], sr)