vall-e/vall_e/models/arch/transformer.py

217 lines
5.3 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
"""
# https://github.com/enhuiz/vall-e/
"""
import math
import torch
import torch.nn.functional as F
import traceback
from typing import Literal, overload
from functools import partial
from einops import rearrange
from torch import Tensor, einsum, nn
from torch.utils.checkpoint import checkpoint
2024-06-06 01:30:43 +00:00
from ...utils import wrapper as ml
2024-06-04 02:28:49 +00:00
class AdaLN(nn.Module):
    """Adaptive LayerNorm whose affine scale/shift are looked up by an integer level.

    Args:
        d_model: feature size that is normalised (last dim of the input).
        n_levels: number of distinct levels, i.e. rows in the scale/shift table.
        eps: epsilon passed to ``F.layer_norm``.
        k, c: AdaNorm constants used in the rescale ``c * (1 - k * h) * h``.

    The table is zero-initialised, so at init the affine part is the identity
    (``exp(0) * h + 0``).
    """

    def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
        super().__init__()
        self.eps, self.k, self.c = eps, k, c
        # Each row stores [log_gamma | beta], hence 2 * d_model columns.
        self.emb = nn.Embedding(n_levels, d_model * 2)
        nn.init.zeros_(self.emb.weight)

    def forward(self, x, l):
        """
        Args:
            x: (b t d) input features.
            l: (b) integer level index per sample.
        Returns:
            (b t d) normalised, level-conditioned features.
        """
        h = F.layer_norm(x, x.shape[-1:], eps=self.eps)
        # The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135)
        # performed worse than vanilla LayerNorm.
        # The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN.
        # Did they use AdaNorm inside AdaLN? (as follows)
        # The (k * h) damping factor is detached, so no gradient flows through it.
        damp = (self.k * h).detach()
        h = self.c * (1 - damp) * h
        # (b 2d) -> (b 1 2d) so the per-sample scale/shift broadcast over time.
        log_gamma, beta = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
        return log_gamma.exp() * h + beta
2023-08-02 21:53:35 +00:00
class SinusoidalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal embedding of scalar positions.

    With ``half = d_model // 2`` the frequencies are
    ``omega_i = exp(-ln(1e4) * i / half)`` for ``i in [0, half)``, and a value
    ``x`` maps to ``[sin(omega * x), cos(omega * x)]`` (all sines, then all
    cosines).
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        half = self.d_half
        ratios = torch.arange(half, dtype=torch.float32) / half
        omega = torch.exp(-math.log(1e4) * ratios)
        self.omega: torch.Tensor
        # Non-persistent: derived purely from d_model, so kept out of state_dict.
        self.register_buffer("omega", omega, persistent=False)

    @property
    def d_half(self):
        """Half of ``d_model``; ``d_model`` must be even."""
        assert self.d_model % 2 == 0, "Only support even d_model."
        return self.d_model // 2

    def forward(self, x):
        """
        Args:
            x: (...)
        Returns:
            pe: (... d)
        """
        # Shape omega as (1, ..., 1, d/2) so it broadcasts against x[..., None].
        freqs = self.omega.view(*([1] * x.dim()), -1)
        angles = freqs * x.unsqueeze(-1)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

    def get_pe(self, n: int):
        """
        Args:
            n: number of positions (0 .. n-1)
        Returns:
            pe: (n d)
        """
        positions = torch.arange(n, device=self.omega.device)
        return self.forward(positions)

    def add_pe(self, x):
        """
        Args:
            x: (b t c); positions run along dim 1, broadcast over the batch.
        """
        return x + self.get_pe(x.shape[1]).unsqueeze(0)
class Attention(nn.Module):
    """Multi-head self-attention with key/query padding masking and optional causal masking.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        causal: when True, position i may only attend to positions j <= i.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, causal):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        dim_head = d_model // n_heads
        self.causal = causal
        self.n_heads = n_heads
        self.scale = dim_head**-0.5
        # One fused projection produces queries, keys and values.
        self.to_qkv = ml.Linear(d_model, d_model * 3, bias=False)
        self.to_out = ml.Linear(d_model, d_model)

    def forward(self, x, m):
        """
        Args:
            x: (b t c)
            m: (b t 1), 1 is data, 0 is padding
        Returns:
            x: (b t c), zeroed at padded positions
        """
        h = self.n_heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v))
        e = einsum("b i h d, b j h d -> b i j h", q, k)
        e = e * self.scale
        # (b 1 t 1) * (b t 1 1) -> (b i j 1): nonzero only where query i and key j are both data.
        kpm = m.unsqueeze(1) * m.unsqueeze(2)
        if self.causal:
            # Bound to `kpm_cast` (not `k`) so the key tensor above is not shadowed.
            # NOTE(review): ml.autocast presumably converts kpm between bfloat16/float16
            # so that tril() has a kernel for it — confirm in utils.wrapper.
            with ml.autocast(kpm, torch.bfloat16, torch.float16) as kpm_cast:
                kpm = kpm_cast.squeeze(-1).tril().unsqueeze(-1)  # b i j 1
        # Finite minimum instead of -inf: a fully masked (padding) query row then
        # softmaxes to a uniform distribution rather than NaN; such rows are
        # zeroed by the final `* m` anyway.
        e = e.masked_fill(kpm == 0, -torch.finfo(e.dtype).max)
        a = e.softmax(dim=2)  # Normalize on j, i.e. key
        o = einsum("b i j h, b j h d -> b i h d", a, v)
        o = o.flatten(-2)
        o = self.to_out(o)  # b t c
        o = o * m
        return o
class PrenormResidual(nn.Module):
def __init__(
self,
block,
d_model,
p_dropout,
requires_mask=False,
norm_type="ln",
n_levels: int | None = None,
):
super().__init__()
self.block = block
self.requires_mask = requires_mask
self.norm_type = norm_type
if norm_type == "ln":
self.norm = nn.LayerNorm(d_model)
elif norm_type == "adaln":
assert n_levels is not None
self.norm = AdaLN(d_model, n_levels)
else:
raise NotImplementedError(norm_type)
self.dropout = nn.Dropout(p_dropout)
def forward(self, x, m, l):
"""
Args:
x: input (b t d)
m: mask (b t 1), 1 is valuable and 0 is padding
l: level to use, required only for AdaLN
"""
nopts = {"l": l} if self.norm_type == "adaln" else {}
bopts = {"m": m} if self.requires_mask else {}
x = x + self.dropout(self.block(self.norm(x, **nopts) * m, **bopts))
return x * m
class Block(nn.Sequential):
    """One transformer layer: pre-norm self-attention, then a pre-norm feed-forward network.

    Both sub-layers are ``PrenormResidual`` wrappers that share the same norm
    configuration. When ``activation_checkpointing`` is set and the input
    requires grad, only the attention sub-layer is run through
    ``torch.utils.checkpoint.checkpoint`` (non-reentrant); the feed-forward
    sub-layer is always called directly.
    """

    def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True):
        super().__init__()
        self.activation_checkpointing = activation_checkpointing

        # Keyword arguments common to both residual wrappers.
        residual_kwargs = dict(
            d_model=d_model,
            p_dropout=p_dropout,
            norm_type=norm_type,
            n_levels=n_levels,
        )

        # Attention is built before the feed-forward stack so parameter
        # initialisation consumes the RNG in a fixed order.
        self.attn = PrenormResidual(
            Attention(d_model, n_heads, causal),
            requires_mask=True,
            **residual_kwargs,
        )

        hidden = 4 * d_model  # conventional 4x feed-forward expansion
        self.ffn = PrenormResidual(
            nn.Sequential(
                ml.Linear(d_model, hidden),
                nn.GELU(),
                nn.Dropout(p_dropout),
                ml.Linear(hidden, d_model),
            ),
            **residual_kwargs,
        )

    def forward(self, x, m, l):
        """
        Args:
            x: (b t c)
            m: (b t 1)
            l: (b)
        Returns:
            (b t c)
        """
        checkpointed = self.activation_checkpointing and x.requires_grad
        if checkpointed:
            hidden = checkpoint(self.attn, x, m, l, use_reentrant=False)
        else:
            hidden = self.attn(x, m, l)
        return self.ffn(hidden, m, l)