AdaLN and sample-wise quant level sampling

This commit is contained in:
enhuiz 2023-01-13 00:32:34 +08:00
parent bead906d72
commit 36e8894f5c
8 changed files with 179 additions and 100 deletions

View File

@ -42,4 +42,6 @@ python -m vall_e.train yaml=config/your_data/ar_or_nar.yml
- [x] Audio decoding from tokens - [x] Audio decoding from tokens
- [x] NAR model for the rest quantizers - [x] NAR model for the rest quantizers
- [x] Trainers for both models - [x] Trainers for both models
- [x] Implement AdaLN for NAR model.
- [x] Sample-wise quantization level sampling for NAR training.
- [ ] Pre-trained checkpoint and demos on LibriTTS - [ ] Pre-trained checkpoint and demos on LibriTTS

View File

@ -4,3 +4,4 @@ spkr_name_getter: "lambda p: p.parts[-3]"
model: ar-quarter model: ar-quarter
batch_size: 8 batch_size: 8
eval_batch_size: 8 eval_batch_size: 8
eval_every: 10_000

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -8,7 +8,7 @@ from .base import Base
class AR(Base): class AR(Base):
@property @property
def n_levels(self): def n_resp_levels(self):
return 1 return 1
@property @property
@ -19,6 +19,10 @@ class AR(Base):
def use_stop_token(self): def use_stop_token(self):
return True return True
@property
def norm_type(self):
return "ln"
def _prune(self, l: Tensor): def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero() indices = (l == self.stop_token).nonzero()
if len(indices) == 0: if len(indices) == 0:
@ -38,7 +42,7 @@ class AR(Base):
proms_list, proms_list,
resp_list, resp_list,
resp_list, resp_list,
quant_level=0, quant_levels=None,
shift_targ_list=True, shift_targ_list=True,
return_all_resp=False, return_all_resp=False,
) )
@ -55,14 +59,13 @@ class AR(Base):
resp_list: list[Tensor] = [ resp_list: list[Tensor] = [
torch.zeros(0, device=device).long() for _ in text_list torch.zeros(0, device=device).long() for _ in text_list
] ]
stopped = [False] * len(text_list) stopped = torch.zeros(len(text_list), device=device).bool()
for _ in trange(max_steps): for _ in trange(max_steps):
r = super().forward(text_list, proms_list, resp_list) r = super().forward(text_list, proms_list, resp_list)
stopped |= r == self.stop_token
for i, ri in enumerate(r): for i, ri in enumerate(r):
if ri.item() == self.stop_token:
stopped[i] = True
resp_list[i] = torch.cat([resp_list[i], ri[None]]) resp_list[i] = torch.cat([resp_list[i], ri[None]])
if all(stopped): if stopped.all().item():
break break
pruned = [self._prune(r) for r in resp_list] pruned = [self._prune(r) for r in resp_list]
return pruned return pruned

View File

@ -4,7 +4,7 @@ from typing import Literal, overload
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange, repeat
from torch import Tensor, einsum, nn from torch import Tensor, einsum, nn
from torch.distributions import Categorical from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
@ -89,12 +89,12 @@ class SinusodialEmbedding(nn.Module):
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, d_model, num_heads, casual): def __init__(self, d_model, n_heads, casual):
super().__init__() super().__init__()
assert d_model % num_heads == 0 assert d_model % n_heads == 0
dim_head = d_model // num_heads dim_head = d_model // n_heads
self.casual = casual self.casual = casual
self.num_heads = num_heads self.n_heads = n_heads
self.scale = dim_head**-0.5 self.scale = dim_head**-0.5
self.to_qkv = nn.Linear(d_model, d_model * 3, bias=False) self.to_qkv = nn.Linear(d_model, d_model * 3, bias=False)
self.to_out = nn.Linear(d_model, d_model) self.to_out = nn.Linear(d_model, d_model)
@ -107,7 +107,7 @@ class Attention(nn.Module):
Returns: Returns:
x: (b t c) x: (b t c)
""" """
h = self.num_heads h = self.n_heads
q, k, v = self.to_qkv(x).chunk(3, dim=-1) q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v)) q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v))
@ -132,52 +132,91 @@ class Attention(nn.Module):
return o return o
class AdaLN(nn.Module):
    """Adaptive layer normalization conditioned on a per-sample level index.

    The input is layer-normalized over its last dimension (without any
    learned affine parameters), then scaled and shifted by values looked up
    from an embedding table indexed by the level ``l``.
    """

    def __init__(self, d_model, n_levels, eps=1e-5):
        """
        Args:
            d_model: size of the normalized (last) dimension
            n_levels: number of distinct levels, i.e. rows in the table
            eps: epsilon forwarded to ``F.layer_norm``
        """
        super().__init__()
        self.eps = eps
        # Each row stores [log-scale | shift], both of width d_model.
        self.emb = nn.Embedding(n_levels, d_model * 2)
        # Zero init gives scale exp(0) = 1 and shift 0, so the module starts
        # out as a plain (non-affine) layer norm for every level.
        nn.init.zeros_(self.emb.weight)

    def forward(self, x, l):
        """
        Args:
            x: input whose last dimension is normalized; the unsqueeze below
               assumes a (b t d) layout -- TODO confirm against callers
            l: integer level indices used to look up the scale and shift,
               presumably of shape (b)
        Returns:
            Tensor with the same shape as ``x``.
        """
        # (b, 2d) -> (b, 1, 2d) so the affine terms broadcast over dim 1.
        log_gamma, beta = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
        normed = F.layer_norm(x, x.shape[-1:], eps=self.eps)
        # The scale is stored in log-space, so it is always positive.
        return log_gamma.exp() * normed + beta
class PrenormResidual(nn.Module): class PrenormResidual(nn.Module):
def __init__(self, block, d_model, dropout, requires_mask=False): def __init__(
self,
block,
d_model,
p_dropout,
requires_mask=False,
norm_type="ln",
n_levels: int | None = None,
):
super().__init__() super().__init__()
self.block = block self.block = block
self.requires_mask = requires_mask self.requires_mask = requires_mask
self.norm = nn.LayerNorm(d_model) self.norm_type = norm_type
self.dropout = nn.Dropout(dropout) if norm_type == "ln":
self.norm = nn.LayerNorm(d_model)
elif norm_type == "adaln":
assert n_levels is not None
self.norm = AdaLN(d_model, n_levels)
self.dropout = nn.Dropout(p_dropout)
def forward(self, x, m): def forward(self, x, m, l):
opts = {"m": m} if self.requires_mask else {} """
x = x + self.dropout(self.block(self.norm(x) * m, **opts)) Args:
x: input (b t d)
m: mask (b t 1), 1 is valid and 0 is padding
l: level to use, required only for AdaLN
"""
nopts = {"l": l} if self.norm_type == "adaln" else {}
bopts = {"m": m} if self.requires_mask else {}
x = x + self.dropout(self.block(self.norm(x, **nopts) * m, **bopts))
return x * m return x * m
class Block(nn.Sequential): class Block(nn.Sequential):
def __init__(self, d_model, num_heads, dropout, casual): def __init__(self, d_model, n_heads, p_dropout, casual, norm_type, n_levels):
super().__init__() super().__init__()
self.attn = PrenormResidual( self.attn = PrenormResidual(
Attention(d_model, num_heads, casual), Attention(d_model, n_heads, casual),
d_model=d_model, d_model=d_model,
dropout=dropout, p_dropout=p_dropout,
requires_mask=True, requires_mask=True,
norm_type=norm_type,
n_levels=n_levels,
) )
self.ffn = PrenormResidual( self.ffn = PrenormResidual(
nn.Sequential( nn.Sequential(
nn.Linear(d_model, d_model * 4), nn.Linear(d_model, d_model * 4),
nn.GELU(), nn.GELU(),
nn.Dropout(dropout), nn.Dropout(p_dropout),
nn.Linear(d_model * 4, d_model), nn.Linear(d_model * 4, d_model),
), ),
d_model=d_model, d_model=d_model,
dropout=dropout, p_dropout=p_dropout,
norm_type=norm_type,
n_levels=n_levels,
) )
def forward(self, x, m): def forward(self, x, m, l):
""" """
Args: Args:
x: (b t c) x: (b t c)
m: (b t 1) m: (b t 1)
l: (b)
""" """
poor_in_vram = True poor_in_vram = True
if x.requires_grad and poor_in_vram: if x.requires_grad and poor_in_vram:
x = checkpoint(self.attn, x, m) x = checkpoint(self.attn, x, m, l)
else: else:
x = self.attn(x, m) x = self.attn(x, m, l)
x = self.ffn(x, m) x = self.ffn(x, m, l)
return x return x
@ -188,22 +227,57 @@ class Embedding(nn.Embedding):
return super().forward(torch.cat(x_list)).split([*map(len, x_list)]) return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
class MultiEmbedding(nn.Module): class AdditiveMultiEmbedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, n_levels): """
super().__init__() This embedding sums embeddings from all levels.
"""
def __init__(self, n_levels, n_tokens, token_dim):
self.n_levels = n_levels self.n_levels = n_levels
self.num_embeddings = num_embeddings self.n_tokens = n_tokens
self.emb = nn.Embedding(n_levels * num_embeddings, embedding_dim) super().__init__(n_levels * n_tokens, token_dim)
def forward(self, x_list: list[Tensor]) -> list[Tensor]: def forward(self, x_list: list[Tensor]) -> list[Tensor]:
if len(x_list) == 0: if len(x_list) == 0:
return [] return []
x = torch.cat(x_list) x = torch.cat(x_list)
assert x.shape[1] == self.n_levels assert x.shape[1] == self.n_levels
w = rearrange(self.emb.weight, "(q k) d -> q k d", q=self.n_levels) w = rearrange(self.weight, "(q k) d -> q k d", q=self.n_levels)
x = F.one_hot(x, num_classes=self.num_embeddings).float() # n q -> n q k x = F.one_hot(x, num_classes=self.n_tokens).float() # n q -> n q k
x = einsum("q k d, n q k -> n d", w, x) x = einsum("q k d, n q k -> n d", w, x)
return x.split([*map(len, x_list)]) x_list = x.split([*map(len, x_list)])
return x_list
class SelectiveMultiEmbedding(nn.Embedding):
"""
This embedding pick up the embedding at the certain level.
"""
def __init__(self, n_levels, n_tokens_per_level, token_dim):
self.n_tokens_per_level = n_tokens_per_level
super().__init__(n_levels, n_tokens_per_level * token_dim)
def forward(self, x_list: list[Tensor], l: Tensor | None = None):
"""
Args:
x_list: [(t)], tokens
l: (b), levels, if none, pick the first
"""
x = pad_sequence(x_list, batch_first=True) # b t
if l is not None:
w = super().forward(l) # b d
else:
w = repeat(self.weight[0], "d -> b d", b=len(x))
w = rearrange(w, "b (k d) -> b k d", k=self.n_tokens_per_level)
x = F.one_hot(x, num_classes=self.n_tokens_per_level).float() # b t k
x = einsum("b k d, b t k -> b t d", w, x)
x_list = [xi[:li] for xi, li in zip(x, map(len, x_list))]
return x_list
def _join(x: tuple[Tensor], sep: Tensor): def _join(x: tuple[Tensor], sep: Tensor):
@ -224,13 +298,17 @@ class Base(nn.Module):
raise NotImplementedError raise NotImplementedError
@property @property
def n_levels(self) -> int: def n_resp_levels(self) -> int:
raise NotImplementedError raise NotImplementedError
@property @property
def use_stop_token(self) -> bool: def use_stop_token(self) -> bool:
raise NotImplementedError raise NotImplementedError
@property
def norm_type(self):
raise NotImplementedError
@property @property
def n_prom_levels(self) -> int: def n_prom_levels(self) -> int:
return 8 return 8
@ -247,9 +325,9 @@ class Base(nn.Module):
super().__init__() super().__init__()
self.n_tokens = n_tokens self.n_tokens = n_tokens
n_levels = self.n_levels
casual = self.casual casual = self.casual
# +1 to include the stop token
n_stop_tokens = 1 if self.use_stop_token else 0 n_stop_tokens = 1 if self.use_stop_token else 0
n_resp_tokens = n_tokens + n_stop_tokens n_resp_tokens = n_tokens + n_stop_tokens
@ -257,20 +335,31 @@ class Base(nn.Module):
# It's not clear whether the whole prom are used or only the first level quantization # It's not clear whether the whole prom are used or only the first level quantization
# Just use all of them as it is more sufficient and we don't need to sample it, or do we? # Just use all of them as it is more sufficient and we don't need to sample it, or do we?
self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=self.n_prom_levels) self.prom_emb = AdditiveMultiEmbedding(self.n_prom_levels, n_tokens, d_model)
# +1 to include the stop token
# Note that, for different levels, I don't use AdaLN for simplicity # Note that, for different levels, I don't use AdaLN for simplicity
# Use different embeddings might be enough. # Use different embeddings might be enough.
self.resp_embs = nn.ModuleList( if self.n_resp_levels:
[Embedding(n_resp_tokens, d_model) for _ in range(n_levels)] self.resp_emb = SelectiveMultiEmbedding(
) self.n_resp_levels, n_resp_tokens, d_model
)
self.sin_emb = SinusodialEmbedding(d_model) self.sin_emb = SinusodialEmbedding(d_model)
self.sep = nn.Parameter(torch.randn(d_model)) self.sep = nn.Parameter(torch.randn(d_model))
blocks = [Block(d_model, n_heads, p_dropout, casual) for _ in range(n_layers)] blocks = [
Block(
d_model=d_model,
n_heads=n_heads,
p_dropout=p_dropout,
casual=casual,
norm_type=self.norm_type,
n_levels=self.n_resp_levels,
)
for _ in range(n_layers)
]
self.blocks = nn.ModuleList(blocks) self.blocks = nn.ModuleList(blocks)
self.classifier = nn.Linear(d_model, n_resp_tokens) self.classifier = nn.Linear(d_model, n_resp_tokens)
@ -302,7 +391,7 @@ class Base(nn.Module):
proms_list: list[Tensor], proms_list: list[Tensor],
resp_list: list[Tensor], resp_list: list[Tensor],
targ_list: list[Tensor] | None = None, targ_list: list[Tensor] | None = None,
quant_level: int = 0, quant_levels: Tensor | None = None,
shift_targ_list: bool = False, shift_targ_list: bool = False,
return_all_resp: Literal[False] = False, return_all_resp: Literal[False] = False,
) -> Tensor: ) -> Tensor:
@ -315,7 +404,7 @@ class Base(nn.Module):
proms_list: list[Tensor], proms_list: list[Tensor],
resp_list: list[Tensor], resp_list: list[Tensor],
targ_list: list[Tensor] | None = None, targ_list: list[Tensor] | None = None,
quant_level: int = 0, quant_levels: Tensor | None = None,
shift_targ_list: bool = False, shift_targ_list: bool = False,
return_all_resp: Literal[True] = True, return_all_resp: Literal[True] = True,
) -> list[Tensor]: ) -> list[Tensor]:
@ -327,7 +416,7 @@ class Base(nn.Module):
proms_list: list[Tensor], proms_list: list[Tensor],
resp_list: list[Tensor], resp_list: list[Tensor],
targ_list: list[Tensor] | None = None, targ_list: list[Tensor] | None = None,
quant_level: int = 0, quant_levels: Tensor | None = None,
shift_targ_list: bool = False, shift_targ_list: bool = False,
return_all_resp: bool = False, return_all_resp: bool = False,
): ):
@ -337,7 +426,7 @@ class Base(nn.Module):
proms_list: [t' k] * b proms_list: [t' k] * b
resp_list: [t''] * b, one quantization level only resp_list: [t''] * b, one quantization level only
targ_list: [t''] * b, one quantization level only, when given, loss will be computed targ_list: [t''] * b, one quantization level only, when given, loss will be computed
quant_level: specify which quant_level to feed forward, used in NAR mode. quant_levels: specify which quant_levels to feed forward, used in NAR mode.
shift_targ_list: whether to shift target list when computing loss. True if AR. shift_targ_list: whether to shift target list when computing loss. True if AR.
return_all_resp: True if NAR. return_all_resp: True if NAR.
Returns: Returns:
@ -346,7 +435,7 @@ class Base(nn.Module):
x_list = self._samplewise_merge_tensors( x_list = self._samplewise_merge_tensors(
self.text_emb(text_list), self.text_emb(text_list),
self.prom_emb(proms_list), self.prom_emb(proms_list),
self.resp_embs[quant_level](resp_list), self.resp_emb(resp_list, quant_levels),
sep=self.sep, sep=self.sep,
) )
@ -354,7 +443,7 @@ class Base(nn.Module):
x = self.sin_emb.add_pe(x) x = self.sin_emb.add_pe(x)
for block in self.blocks: for block in self.blocks:
x = block(x, m) x = block(x, m, quant_levels)
h = self.classifier(x) * m h = self.classifier(x) * m

View File

@ -9,7 +9,7 @@ from .base import Base
class NAR(Base): class NAR(Base):
@property @property
def n_levels(self): def n_resp_levels(self):
return 7 return 7
@property @property
@ -20,6 +20,10 @@ class NAR(Base):
def use_stop_token(self): def use_stop_token(self):
return False return False
@property
def norm_type(self):
return "adaln"
def forward( def forward(
self, self,
text_list: list[Tensor], text_list: list[Tensor],
@ -44,68 +48,48 @@ class NAR(Base):
if resps_list is not None: if resps_list is not None:
levels = {r.shape[-1] for r in resps_list} levels = {r.shape[-1] for r in resps_list}
if any(level != self.n_levels + 1 for level in levels): if any(level != self.n_resp_levels + 1 for level in levels):
raise ValueError( raise ValueError(
f"resps_list should have exactly {self.n_levels + 1} levels, but got {levels}." f"resps_list should have exactly {self.n_resp_levels + 1} levels, but got {levels}."
) )
if resp_list is not None: device = text_list[0].device
if resp_list is None:
assert resps_list is not None
quant_levels = torch.randint(0, self.n_resp_levels, (len(resps_list),))
curr_resp_list = [o[..., l] for o, l in zip(resps_list, quant_levels)]
next_resp_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)]
quant_levels = quant_levels.to(device=device)
_ = super().forward(
text_list,
proms_list,
curr_resp_list,
next_resp_list,
return_all_resp=True,
shift_targ_list=False,
quant_levels=quant_levels,
)
# Yes, just nothing as we are training
hyp_resp_lists = []
else:
hyp_resp_lists = [resp_list] hyp_resp_lists = [resp_list]
for i in range(self.n_levels): for level in range(self.n_resp_levels):
quant_levels = torch.full((len(text_list),), level, device=device)
hyp_resp_list = super().forward( hyp_resp_list = super().forward(
text_list, text_list,
proms_list, proms_list,
hyp_resp_lists[-1], hyp_resp_lists[-1],
return_all_resp=True, return_all_resp=True,
shift_targ_list=False, shift_targ_list=False,
quant_level=i, quant_levels=quant_levels,
) )
hyp_resp_lists.append(hyp_resp_list) hyp_resp_lists.append(hyp_resp_list)
else:
assert resps_list is not None
# I noticed that VALL-E randomly sample a layer,
# which will be more memory efficient, let's do it.
# For simplicity, do it on per batch level instead of per sample level
# does that matter?
# Old code:
# loss = {}
# resp_list = [o[..., 0] for o in resps_list]
# hyp_resp_lists = [resp_list]
# for i in range(self.n_levels):
# resp_list = [o[..., 0] for o in resps_list]
# next_resp_list = [o[..., i + 1] for o in resps_list]
# hyp_resp_list = super().forward(
# text_list,
# proms_list,
# resp_list,
# next_resp_list,
# return_all_resp=True,
# shift_targ_list=False,
# quant_level=i,
# )
# hyp_resp_lists.append(hyp_resp_list)
# loss |= {f"l{i}": self.loss}
# del self.loss
# self.loss = loss
quant_level = random.randint(0, self.n_levels - 1)
cur_resp_list = [o[..., quant_level] for o in resps_list]
next_resp_list = [o[..., quant_level + 1] for o in resps_list]
_ = super().forward(
text_list,
proms_list,
cur_resp_list,
next_resp_list,
return_all_resp=True,
shift_targ_list=False,
quant_level=quant_level,
)
# Yes, just nothing as we are training
hyp_resp_lists = []
hyp_resps_list = [ hyp_resps_list = [
*map(lambda ts: torch.stack(ts, dim=-1), zip(*hyp_resp_lists)) *map(lambda ts: torch.stack(ts, dim=-1), zip(*hyp_resp_lists))