"""
# https://github.com/enhuiz/vall-e/
"""

import math
import torch
import torch.nn.functional as F
import traceback

from typing import Literal, overload
from functools import partial
from einops import rearrange

from torch import Tensor, einsum, nn
from torch.utils.checkpoint import checkpoint

from ..utils import wrapper as ml


class AdaLN(nn.Module):
    def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
        super().__init__()
        self.eps = eps
        self.emb = nn.Embedding(n_levels, d_model * 2)
        self.k = k
        self.c = c
        nn.init.zeros_(self.emb.weight)

    def forward(self, x, l):
        h = F.layer_norm(x, x.shape[-1:], eps=self.eps)

        # The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135)
        # performed worse than vanilla LayerNorm.
        # The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) when introducing AdaLN.
        # Did they use AdaNorm inside AdaLN? (as follows)
        h = self.c * (1 - (self.k * h).detach()) * h

        logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
        y = logγ.exp() * h + β

        return y

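# Illustrative usage sketch (not part of the original file): AdaLN acts like a LayerNorm
# whose scale and shift are looked up from a per-quantizer-level embedding, with the
# AdaNorm-style rescaling above applied in between, i.e. roughly
# y = exp(logγ_l) * (c * (1 - detach(k * ĥ)) * ĥ) + β_l with ĥ = LayerNorm(x).
# Shapes follow this file's conventions; d_model/n_levels values below are made up.
#
#   norm = AdaLN(d_model=1024, n_levels=8)
#   x = torch.randn(2, 50, 1024)                    # (b t d)
#   l = torch.full((2,), 3, dtype=torch.long)       # one quantizer level index per batch item, (b,)
#   y = norm(x, l)                                  # (b t d), same shape as x

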
class SinusoidalEmbedding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        exponent = torch.arange(self.d_half, dtype=torch.float32)
        exponent = exponent / self.d_half
        omega = torch.exp(-math.log(1e4) * exponent)
        self.omega: torch.Tensor
        self.register_buffer("omega", omega, persistent=False)

    @property
    def d_half(self):
        assert self.d_model % 2 == 0, "Only even d_model is supported."
        return self.d_model // 2

    def forward(self, x):
        """
        Args:
            x: (...)
        Returns:
            pe: (... d)
        """
        omega = self.omega

        while omega.dim() <= x.dim():
            omega = omega.unsqueeze(0)  # (... d)

        x = x.unsqueeze(-1)  # (... 1)
        x = omega * x
        x = torch.cat([x.sin(), x.cos()], dim=-1)

        return x

    def get_pe(self, n: int):
        """
        Args:
            n: int
        Returns:
            pe: (n d)
        """
        device = self.omega.device
        return self.forward(torch.arange(n, device=device))

    def add_pe(self, x):
        """
        Args:
            x: (b t c)
        """
        e = self.get_pe(x.shape[1])  # t d
        e = e[None]  # b t d
        x = x + e
        return x

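# Illustrative usage sketch (not part of the original file): add_pe builds the classic
# sin/cos table for positions 0..t-1 and adds it to the input, broadcasting over the batch.
# The d_model value below is made up and must be even.
#
#   pos_emb = SinusoidalEmbedding(d_model=1024)
#   x = torch.randn(2, 50, 1024)                    # (b t c)
#   x = pos_emb.add_pe(x)                           # (b t c)

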
class Attention(nn.Module):
    def __init__(self, d_model, n_heads, causal):
        super().__init__()
        assert d_model % n_heads == 0
        dim_head = d_model // n_heads
        self.causal = causal
        self.n_heads = n_heads
        self.scale = dim_head**-0.5

        self.to_qkv = ml.Linear(d_model, d_model * 3, bias=False)
        self.to_out = ml.Linear(d_model, d_model)

    def forward(self, x, m):
        """
        Args:
            x: (b t c)
            m: (b t 1), 1 is data, 0 is padding
        Returns:
            x: (b t c)
        """
        h = self.n_heads

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v))

        e = einsum("b i h d, b j h d -> b i j h", q, k)
        e = e * self.scale

        kpm = m.unsqueeze(1) * m.unsqueeze(2)  # b i j 1

        if self.causal:
            with ml.autocast(kpm, torch.bfloat16, torch.float16) as kpm_cast:
                kpm = kpm_cast.squeeze(-1).tril().unsqueeze(-1)  # b i j 1

        e = e.masked_fill(kpm == 0, -torch.finfo(e.dtype).max)
        a = e.softmax(dim=2)  # Normalize on j, i.e. key

        o = einsum("b i j h, b j h d -> b i h d", a, v)
        o = o.flatten(-2)
        o = self.to_out(o)  # b t c

        o = o * m

        return o

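# Illustrative usage sketch (not part of the original file): m marks valid timesteps;
# the layer combines it into a key-padding mask (plus a causal tril when causal=True)
# before the softmax, and zeroes padded outputs via `o * m`. Sizes below are made up.
#
#   attn = Attention(d_model=1024, n_heads=16, causal=True)
#   x = torch.randn(2, 50, 1024)                                             # (b t c)
#   lens = torch.tensor([50, 32])
#   m = (torch.arange(50)[None, :] < lens[:, None]).unsqueeze(-1).float()    # (b t 1)
#   y = attn(x, m)                                                           # (b t c)

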
class PrenormResidual(nn.Module):
    def __init__(
        self,
        block,
        d_model,
        p_dropout,
        requires_mask=False,
        norm_type="ln",
        n_levels: int | None = None,
    ):
        super().__init__()
        self.block = block
        self.requires_mask = requires_mask
        self.norm_type = norm_type
        if norm_type == "ln":
            self.norm = nn.LayerNorm(d_model)
        elif norm_type == "adaln":
            assert n_levels is not None
            self.norm = AdaLN(d_model, n_levels)
        else:
            raise NotImplementedError(norm_type)
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x, m, l):
        """
        Args:
            x: input (b t d)
            m: mask (b t 1), 1 is valid and 0 is padding
            l: level to use, required only for AdaLN
        """
        nopts = {"l": l} if self.norm_type == "adaln" else {}
        bopts = {"m": m} if self.requires_mask else {}
        x = x + self.dropout(self.block(self.norm(x, **nopts) * m, **bopts))
        return x * m

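# Illustrative usage sketch (not part of the original file): the wrapper applies the
# pre-norm (LayerNorm or AdaLN), runs the wrapped block on the masked input, then adds
# the residual and re-applies the mask. Hyperparameter values below are made up.
#
#   ffn = PrenormResidual(
#       nn.Sequential(ml.Linear(1024, 4096), nn.GELU(), ml.Linear(4096, 1024)),
#       d_model=1024,
#       p_dropout=0.1,
#       norm_type="adaln",
#       n_levels=8,
#   )
#   y = ffn(x, m, l)   # x: (b t d), m: (b t 1), l: (b,) level indices

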
class Block(nn.Sequential):
    def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True):
        super().__init__()

        self.activation_checkpointing = activation_checkpointing
        self.attn = PrenormResidual(
            Attention(d_model, n_heads, causal),
            d_model=d_model,
            p_dropout=p_dropout,
            requires_mask=True,
            norm_type=norm_type,
            n_levels=n_levels,
        )

        n_ff = d_model * 4  # feed-forward hidden size, e.g. 1024 * 4 = 4096
        self.ffn = PrenormResidual(
            nn.Sequential(
                ml.Linear(d_model, n_ff),
                nn.GELU(),
                nn.Dropout(p_dropout),
                ml.Linear(n_ff, d_model),
            ),
            d_model=d_model,
            p_dropout=p_dropout,
            norm_type=norm_type,
            n_levels=n_levels,
        )

    def forward(self, x, m, l):
        """
        Args:
            x: (b t c)
            m: (b t 1)
            l: (b)
        """
        if x.requires_grad and self.activation_checkpointing:
            x = checkpoint(self.attn, x, m, l, use_reentrant=False)
        else:
            x = self.attn(x, m, l)
        x = self.ffn(x, m, l)
        return x

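# Illustrative usage sketch (not part of the original file): Block is one transformer
# layer (masked attention + feed-forward, each wrapped in PrenormResidual). A model
# would typically stack Blocks in an nn.ModuleList and thread the same mask m and
# level indices l through every layer. Hyperparameters below are made up.
#
#   blocks = nn.ModuleList([
#       Block(d_model=1024, n_heads=16, p_dropout=0.1,
#             causal=True, norm_type="adaln", n_levels=8)
#       for _ in range(12)
#   ])
#   for block in blocks:
#       x = block(x, m, l)   # x: (b t c), m: (b t 1), l: (b,)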