AdaLN and sample-wise quant level sampling
This commit is contained in:
parent bead906d72
commit 36e8894f5c
@@ -42,4 +42,6 @@ python -m vall_e.train yaml=config/your_data/ar_or_nar.yml
 - [x] Audio decoding from tokens
 - [x] NAR model for the rest quantizers
 - [x] Trainers for both models
+- [x] Implement AdaLN for NAR model.
+- [x] Sample-wise quantization level sampling for NAR training.
 - [ ] Pre-trained checkpoint and demos on LibriTTS
@@ -4,3 +4,4 @@ spkr_name_getter: "lambda p: p.parts[-3]"
 model: ar-quarter
 batch_size: 8
 eval_batch_size: 8
+eval_every: 10_000
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -8,7 +8,7 @@ from .base import Base

 class AR(Base):
     @property
-    def n_levels(self):
+    def n_resp_levels(self):
         return 1

     @property
@@ -19,6 +19,10 @@ class AR(Base):
     def use_stop_token(self):
         return True

+    @property
+    def norm_type(self):
+        return "ln"
+
     def _prune(self, l: Tensor):
         indices = (l == self.stop_token).nonzero()
         if len(indices) == 0:
@@ -38,7 +42,7 @@ class AR(Base):
             proms_list,
             resp_list,
             resp_list,
-            quant_level=0,
+            quant_levels=None,
             shift_targ_list=True,
             return_all_resp=False,
         )
@@ -55,14 +59,13 @@ class AR(Base):
         resp_list: list[Tensor] = [
             torch.zeros(0, device=device).long() for _ in text_list
         ]
-        stopped = [False] * len(text_list)
+        stopped = torch.zeros(len(text_list), device=device).bool()
         for _ in trange(max_steps):
             r = super().forward(text_list, proms_list, resp_list)
+            stopped |= r == self.stop_token
             for i, ri in enumerate(r):
-                if ri.item() == self.stop_token:
-                    stopped[i] = True
                 resp_list[i] = torch.cat([resp_list[i], ri[None]])
-            if all(stopped):
+            if stopped.all().item():
                 break
         pruned = [self._prune(r) for r in resp_list]
         return pruned
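For reference, the stop bookkeeping above now lives in a boolean tensor rather than a Python list, so the per-step check is vectorized over the batch. A minimal standalone sketch of the same bookkeeping, with a made-up stop_token id and random tokens standing in for the model's output:

import torch

stop_token = 1024      # made-up id for illustration; the real one comes from the model
batch_size = 3
device = "cpu"

stopped = torch.zeros(batch_size, device=device).bool()
resp_list = [torch.zeros(0, dtype=torch.long, device=device) for _ in range(batch_size)]

for _ in range(100):
    # stand-in for `r = super().forward(...)`: one sampled token per sequence
    r = torch.randint(0, stop_token + 1, (batch_size,), device=device)
    stopped |= r == stop_token                  # latch sequences that emitted the stop token
    for i, ri in enumerate(r):
        resp_list[i] = torch.cat([resp_list[i], ri[None]])
    if stopped.all().item():                    # every sequence has stopped
        break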
@@ -4,7 +4,7 @@ from typing import Literal, overload

 import torch
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
 from torch import Tensor, einsum, nn
 from torch.distributions import Categorical
 from torch.nn.utils.rnn import pad_sequence
@@ -89,12 +89,12 @@ class SinusodialEmbedding(nn.Module):


 class Attention(nn.Module):
-    def __init__(self, d_model, num_heads, casual):
+    def __init__(self, d_model, n_heads, casual):
         super().__init__()
-        assert d_model % num_heads == 0
-        dim_head = d_model // num_heads
+        assert d_model % n_heads == 0
+        dim_head = d_model // n_heads
         self.casual = casual
-        self.num_heads = num_heads
+        self.n_heads = n_heads
         self.scale = dim_head**-0.5
         self.to_qkv = nn.Linear(d_model, d_model * 3, bias=False)
         self.to_out = nn.Linear(d_model, d_model)
@@ -107,7 +107,7 @@ class Attention(nn.Module):
         Returns:
             x: (b t c)
         """
-        h = self.num_heads
+        h = self.n_heads

         q, k, v = self.to_qkv(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v))
@@ -132,52 +132,91 @@ class Attention(nn.Module):
         return o


+class AdaLN(nn.Module):
+    def __init__(self, d_model, n_levels, eps=1e-5):
+        super().__init__()
+        self.eps = eps
+        self.emb = nn.Embedding(n_levels, d_model * 2)
+        nn.init.zeros_(self.emb.weight)
+
+    def forward(self, x, l):
+        logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
+        h = F.layer_norm(x, x.shape[-1:], eps=self.eps)
+        y = logγ.exp() * h + β
+        return y
+
+
 class PrenormResidual(nn.Module):
-    def __init__(self, block, d_model, dropout, requires_mask=False):
+    def __init__(
+        self,
+        block,
+        d_model,
+        p_dropout,
+        requires_mask=False,
+        norm_type="ln",
+        n_levels: int | None = None,
+    ):
         super().__init__()
         self.block = block
         self.requires_mask = requires_mask
-        self.norm = nn.LayerNorm(d_model)
-        self.dropout = nn.Dropout(dropout)
+        self.norm_type = norm_type
+        if norm_type == "ln":
+            self.norm = nn.LayerNorm(d_model)
+        elif norm_type == "adaln":
+            assert n_levels is not None
+            self.norm = AdaLN(d_model, n_levels)
+        self.dropout = nn.Dropout(p_dropout)

-    def forward(self, x, m):
-        opts = {"m": m} if self.requires_mask else {}
-        x = x + self.dropout(self.block(self.norm(x) * m, **opts))
+    def forward(self, x, m, l):
+        """
+        Args:
+            x: input (b t d)
+            m: mask (b t 1), 1 is valid and 0 is padding
+            l: level to use, required only for AdaLN
+        """
+        nopts = {"l": l} if self.norm_type == "adaln" else {}
+        bopts = {"m": m} if self.requires_mask else {}
+        x = x + self.dropout(self.block(self.norm(x, **nopts) * m, **bopts))
         return x * m


 class Block(nn.Sequential):
-    def __init__(self, d_model, num_heads, dropout, casual):
+    def __init__(self, d_model, n_heads, p_dropout, casual, norm_type, n_levels):
         super().__init__()
         self.attn = PrenormResidual(
-            Attention(d_model, num_heads, casual),
+            Attention(d_model, n_heads, casual),
             d_model=d_model,
-            dropout=dropout,
+            p_dropout=p_dropout,
             requires_mask=True,
+            norm_type=norm_type,
+            n_levels=n_levels,
         )
         self.ffn = PrenormResidual(
             nn.Sequential(
                 nn.Linear(d_model, d_model * 4),
                 nn.GELU(),
-                nn.Dropout(dropout),
+                nn.Dropout(p_dropout),
                 nn.Linear(d_model * 4, d_model),
             ),
             d_model=d_model,
-            dropout=dropout,
+            p_dropout=p_dropout,
+            norm_type=norm_type,
+            n_levels=n_levels,
         )

-    def forward(self, x, m):
+    def forward(self, x, m, l):
         """
         Args:
             x: (b t c)
             m: (b t 1)
+            l: (b)
         """
         poor_in_vram = True
         if x.requires_grad and poor_in_vram:
-            x = checkpoint(self.attn, x, m)
+            x = checkpoint(self.attn, x, m, l)
         else:
-            x = self.attn(x, m)
-        x = self.ffn(x, m)
+            x = self.attn(x, m, l)
+        x = self.ffn(x, m, l)
         return x

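The AdaLN added above is an ordinary layer norm whose scale and shift are looked up from an embedding indexed by the quantization level; zero-initializing that embedding makes the module start as an identity modulation (exp(0) * h + 0 = h). A minimal standalone sketch of the same computation, with made-up shapes:

import torch
import torch.nn.functional as F
from torch import nn

class AdaLNSketch(nn.Module):
    """Layer norm whose affine parameters depend on a discrete level id (mirrors the class above)."""

    def __init__(self, d_model, n_levels, eps=1e-5):
        super().__init__()
        self.eps = eps
        # 2 * d_model: first half is log-gamma (scale), second half is beta (shift)
        self.emb = nn.Embedding(n_levels, d_model * 2)
        nn.init.zeros_(self.emb.weight)  # start as an identity transform

    def forward(self, x, l):
        # x: (b, t, d), l: (b,) long tensor of quantization levels
        log_gamma, beta = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
        h = F.layer_norm(x, x.shape[-1:], eps=self.eps)
        return log_gamma.exp() * h + beta

# Usage with made-up shapes:
x = torch.randn(2, 5, 16)     # batch of 2, length 5, d_model 16
l = torch.tensor([0, 3])      # one quantization level per sample
y = AdaLNSketch(d_model=16, n_levels=7)(x, l)
print(y.shape)                # torch.Size([2, 5, 16])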
@@ -188,22 +227,57 @@ class Embedding(nn.Embedding):
         return super().forward(torch.cat(x_list)).split([*map(len, x_list)])


-class MultiEmbedding(nn.Module):
-    def __init__(self, num_embeddings, embedding_dim, n_levels):
-        super().__init__()
+class AdditiveMultiEmbedding(nn.Embedding):
+    """
+    This embedding sums embeddings from all levels.
+    """
+
+    def __init__(self, n_levels, n_tokens, token_dim):
         self.n_levels = n_levels
-        self.num_embeddings = num_embeddings
-        self.emb = nn.Embedding(n_levels * num_embeddings, embedding_dim)
+        self.n_tokens = n_tokens
+        super().__init__(n_levels * n_tokens, token_dim)

     def forward(self, x_list: list[Tensor]) -> list[Tensor]:
         if len(x_list) == 0:
             return []
         x = torch.cat(x_list)
         assert x.shape[1] == self.n_levels
-        w = rearrange(self.emb.weight, "(q k) d -> q k d", q=self.n_levels)
-        x = F.one_hot(x, num_classes=self.num_embeddings).float()  # n q -> n q k
+        w = rearrange(self.weight, "(q k) d -> q k d", q=self.n_levels)
+        x = F.one_hot(x, num_classes=self.n_tokens).float()  # n q -> n q k
         x = einsum("q k d, n q k -> n d", w, x)
-        return x.split([*map(len, x_list)])
+        x_list = x.split([*map(len, x_list)])
+        return x_list
+
+
+class SelectiveMultiEmbedding(nn.Embedding):
+    """
+    This embedding picks up the embedding at the given level.
+    """
+
+    def __init__(self, n_levels, n_tokens_per_level, token_dim):
+        self.n_tokens_per_level = n_tokens_per_level
+        super().__init__(n_levels, n_tokens_per_level * token_dim)
+
+    def forward(self, x_list: list[Tensor], l: Tensor | None = None):
+        """
+        Args:
+            x_list: [(t)], tokens
+            l: (b), levels, if none, pick the first
+        """
+        x = pad_sequence(x_list, batch_first=True)  # b t
+
+        if l is not None:
+            w = super().forward(l)  # b d
+        else:
+            w = repeat(self.weight[0], "d -> b d", b=len(x))
+
+        w = rearrange(w, "b (k d) -> b k d", k=self.n_tokens_per_level)
+        x = F.one_hot(x, num_classes=self.n_tokens_per_level).float()  # b t k
+        x = einsum("b k d, b t k -> b t d", w, x)
+
+        x_list = [xi[:li] for xi, li in zip(x, map(len, x_list))]
+
+        return x_list


 def _join(x: tuple[Tensor], sep: Tensor):
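The two embeddings above split what MultiEmbedding used to do: AdditiveMultiEmbedding sums one embedding per codebook level for the prompt tokens, while SelectiveMultiEmbedding keeps a separate token table per level and embeds a single-level token sequence with the table chosen by each sample's level id. A rough sketch of the selective lookup with made-up sizes (the shared weight layout, one_hot, and einsum mirror the class above):

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
from torch.nn.utils.rnn import pad_sequence

# Made-up sizes for illustration.
n_levels, n_tokens, d = 4, 10, 8
weight = torch.randn(n_levels, n_tokens * d)   # one flattened (n_tokens x d) table per level

# Two variable-length single-level token sequences and their chosen levels.
x_list = [torch.randint(0, n_tokens, (5,)), torch.randint(0, n_tokens, (3,))]
l = torch.tensor([0, 2])

x = pad_sequence(x_list, batch_first=True)                 # (b, t)
w = rearrange(weight[l], "b (k d) -> b k d", k=n_tokens)   # per-sample token table
x1h = F.one_hot(x, num_classes=n_tokens).float()           # (b, t, k)
emb = einsum("b k d, b t k -> b t d", w, x1h)              # look up tokens in the chosen table
emb_list = [e[: len(xi)] for e, xi in zip(emb, x_list)]    # trim the padding back off
print([e.shape for e in emb_list])                         # [torch.Size([5, 8]), torch.Size([3, 8])]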
@@ -224,13 +298,17 @@ class Base(nn.Module):
         raise NotImplementedError

     @property
-    def n_levels(self) -> int:
+    def n_resp_levels(self) -> int:
         raise NotImplementedError

     @property
     def use_stop_token(self) -> bool:
         raise NotImplementedError

+    @property
+    def norm_type(self):
+        raise NotImplementedError
+
     @property
     def n_prom_levels(self) -> int:
         return 8
@@ -247,9 +325,9 @@ class Base(nn.Module):
         super().__init__()
         self.n_tokens = n_tokens

-        n_levels = self.n_levels
         casual = self.casual

         # +1 to include the stop token
         n_stop_tokens = 1 if self.use_stop_token else 0
         n_resp_tokens = n_tokens + n_stop_tokens
+
@@ -257,20 +335,31 @@ class Base(nn.Module):

         # It's not clear whether the whole prom are used or only the first level quantization
         # Just use all of them as it is more sufficient and we don't need to sample it, or do we?
-        self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=self.n_prom_levels)
+        self.prom_emb = AdditiveMultiEmbedding(self.n_prom_levels, n_tokens, d_model)

-        # +1 to include the stop token
+        # Note that, for different levels, I don't use AdaLN for simplicity
         # Use different embeddings might be enough.
-        self.resp_embs = nn.ModuleList(
-            [Embedding(n_resp_tokens, d_model) for _ in range(n_levels)]
-        )
+        if self.n_resp_levels:
+            self.resp_emb = SelectiveMultiEmbedding(
+                self.n_resp_levels, n_resp_tokens, d_model
+            )

         self.sin_emb = SinusodialEmbedding(d_model)

         self.sep = nn.Parameter(torch.randn(d_model))

-        blocks = [Block(d_model, n_heads, p_dropout, casual) for _ in range(n_layers)]
+        blocks = [
+            Block(
+                d_model=d_model,
+                n_heads=n_heads,
+                p_dropout=p_dropout,
+                casual=casual,
+                norm_type=self.norm_type,
+                n_levels=self.n_resp_levels,
+            )
+            for _ in range(n_layers)
+        ]

         self.blocks = nn.ModuleList(blocks)

         self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -302,7 +391,7 @@ class Base(nn.Module):
         proms_list: list[Tensor],
         resp_list: list[Tensor],
         targ_list: list[Tensor] | None = None,
-        quant_level: int = 0,
+        quant_levels: Tensor | None = None,
         shift_targ_list: bool = False,
         return_all_resp: Literal[False] = False,
     ) -> Tensor:
@@ -315,7 +404,7 @@ class Base(nn.Module):
         proms_list: list[Tensor],
         resp_list: list[Tensor],
         targ_list: list[Tensor] | None = None,
-        quant_level: int = 0,
+        quant_levels: Tensor | None = None,
         shift_targ_list: bool = False,
         return_all_resp: Literal[True] = True,
     ) -> list[Tensor]:
@@ -327,7 +416,7 @@ class Base(nn.Module):
         proms_list: list[Tensor],
         resp_list: list[Tensor],
         targ_list: list[Tensor] | None = None,
-        quant_level: int = 0,
+        quant_levels: Tensor | None = None,
         shift_targ_list: bool = False,
         return_all_resp: bool = False,
     ):
@@ -337,7 +426,7 @@ class Base(nn.Module):
             proms_list: [t' k] * b
             resp_list: [t''] * b, one quantization level only
             targ_list: [t''] * b, one quantization level only, when given, loss will be computed
-            quant_level: specify which quant_level to feed forward, used in NAR mode.
+            quant_levels: specify which quant_levels to feed forward, used in NAR mode.
             shift_targ_list: whether to shift target list when computing loss. True if AR.
             return_all_resp: True if NAR.
         Returns:
@@ -346,7 +435,7 @@ class Base(nn.Module):
         x_list = self._samplewise_merge_tensors(
             self.text_emb(text_list),
             self.prom_emb(proms_list),
-            self.resp_embs[quant_level](resp_list),
+            self.resp_emb(resp_list, quant_levels),
             sep=self.sep,
         )

@@ -354,7 +443,7 @@ class Base(nn.Module):
         x = self.sin_emb.add_pe(x)

         for block in self.blocks:
-            x = block(x, m)
+            x = block(x, m, quant_levels)

         h = self.classifier(x) * m

@@ -9,7 +9,7 @@ from .base import Base

 class NAR(Base):
     @property
-    def n_levels(self):
+    def n_resp_levels(self):
         return 7

     @property
@@ -20,6 +20,10 @@ class NAR(Base):
     def use_stop_token(self):
         return False

+    @property
+    def norm_type(self):
+        return "adaln"
+
     def forward(
         self,
         text_list: list[Tensor],
@@ -44,68 +48,48 @@ class NAR(Base):

         if resps_list is not None:
             levels = {r.shape[-1] for r in resps_list}
-            if any(level != self.n_levels + 1 for level in levels):
+            if any(level != self.n_resp_levels + 1 for level in levels):
                 raise ValueError(
-                    f"resps_list should have exactly {self.n_levels + 1} levels, but got {levels}."
+                    f"resps_list should have exactly {self.n_resp_levels + 1} levels, but got {levels}."
                 )

-        if resp_list is not None:
+        device = text_list[0].device
+
+        if resp_list is None:
+            assert resps_list is not None
+
+            quant_levels = torch.randint(0, self.n_resp_levels, (len(resps_list),))
+
+            curr_resp_list = [o[..., l] for o, l in zip(resps_list, quant_levels)]
+            next_resp_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)]
+
+            quant_levels = quant_levels.to(device=device)
+
+            _ = super().forward(
+                text_list,
+                proms_list,
+                curr_resp_list,
+                next_resp_list,
+                return_all_resp=True,
+                shift_targ_list=False,
+                quant_levels=quant_levels,
+            )
+
+            # Yes, just nothing as we are training
+            hyp_resp_lists = []
+        else:
             hyp_resp_lists = [resp_list]
-            for i in range(self.n_levels):
+            for level in range(self.n_resp_levels):
+                quant_levels = torch.full((len(text_list),), level, device=device)
                 hyp_resp_list = super().forward(
                     text_list,
                     proms_list,
                     hyp_resp_lists[-1],
                     return_all_resp=True,
                     shift_targ_list=False,
-                    quant_level=i,
+                    quant_levels=quant_levels,
                 )
                 hyp_resp_lists.append(hyp_resp_list)
-        else:
-            assert resps_list is not None
-
-            # I noticed that VALL-E randomly sample a layer,
-            # which will be more memory efficient, let's do it.
-            # For simplicity, do it on per batch level instead of per sample level
-            # does that matter?
-
-            # Old code:
-            # loss = {}
-            # resp_list = [o[..., 0] for o in resps_list]
-            # hyp_resp_lists = [resp_list]
-            # for i in range(self.n_levels):
-            # resp_list = [o[..., 0] for o in resps_list]
-            # next_resp_list = [o[..., i + 1] for o in resps_list]
-            # hyp_resp_list = super().forward(
-            # text_list,
-            # proms_list,
-            # resp_list,
-            # next_resp_list,
-            # return_all_resp=True,
-            # shift_targ_list=False,
-            # quant_level=i,
-            # )
-            # hyp_resp_lists.append(hyp_resp_list)
-            # loss |= {f"l{i}": self.loss}
-            # del self.loss
-            # self.loss = loss
-
-            quant_level = random.randint(0, self.n_levels - 1)
-            cur_resp_list = [o[..., quant_level] for o in resps_list]
-            next_resp_list = [o[..., quant_level + 1] for o in resps_list]
-
-            _ = super().forward(
-                text_list,
-                proms_list,
-                cur_resp_list,
-                next_resp_list,
-                return_all_resp=True,
-                shift_targ_list=False,
-                quant_level=quant_level,
-            )
-
-            # Yes, just nothing as we are training
-            hyp_resp_lists = []

         hyp_resps_list = [
             *map(lambda ts: torch.stack(ts, dim=-1), zip(*hyp_resp_lists))
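The new training branch above draws an independent quantization level for every sample in the batch (hence "sample-wise"), then uses the codes at that level as input and the codes one level up as the target. The core of that sampling, as a standalone sketch with fabricated response tensors:

import torch

n_resp_levels = 7
# Fabricated responses: one (t, n_resp_levels + 1) tensor of codec codes per utterance.
resps_list = [
    torch.randint(0, 1024, (40, n_resp_levels + 1)),
    torch.randint(0, 1024, (25, n_resp_levels + 1)),
]

# One random level per sample (sample-wise, not one level per batch).
quant_levels = torch.randint(0, n_resp_levels, (len(resps_list),))

# Inputs are the codes at the sampled level; targets are the next level up.
curr_resp_list = [o[..., l] for o, l in zip(resps_list, quant_levels)]
next_resp_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)]

# curr/next are then fed to Base.forward(..., quant_levels=quant_levels),
# which conditions AdaLN in every block on each sample's level.
print(quant_levels, [c.shape for c in curr_resp_list])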