vall-e/vall_e/models/transformer.py
"""
# https://github.com/enhuiz/vall-e/
"""
import math
import torch
import torch.nn.functional as F
import traceback
from typing import Literal, overload
from functools import partial
from einops import rearrange
from torch import Tensor, einsum, nn
from torch.utils.checkpoint import checkpoint
from ..utils import wrapper as ml

class AdaLN(nn.Module):
    def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
        super().__init__()
        self.eps = eps
        self.emb = nn.Embedding(n_levels, d_model * 2)
        self.k = k
        self.c = c
        nn.init.zeros_(self.emb.weight)

    def forward(self, x, l):
        h = F.layer_norm(x, x.shape[-1:], eps=self.eps)

        # The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135)
        # performed worse than vanilla LayerNorm.
        # The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN.
        # Did they use AdaNorm inside AdaLN? (as follows)
        h = self.c * (1 - (self.k * h).detach()) * h

        logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
        y = logγ.exp() * h + β

        return y
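
# Usage sketch (illustrative only; the sizes below are arbitrary choices, not values
# used by the model). AdaLN layer-normalizes x, applies the AdaNorm-style rescaling
# above, then applies a per-level scale and shift looked up from the embedding.
def _example_adaln():
    norm = AdaLN(d_model=64, n_levels=8)
    x = torch.randn(2, 10, 64)      # (b t d)
    l = torch.tensor([0, 3])        # (b,) one quantizer-level index per sample
    y = norm(x, l)                  # (b t d)
    return y.shape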

class SinusoidalEmbedding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        exponent = torch.arange(self.d_half, dtype=torch.float32)
        exponent = exponent / self.d_half
        omega = torch.exp(-math.log(1e4) * exponent)
        self.omega: torch.Tensor
        self.register_buffer("omega", omega, persistent=False)

    @property
    def d_half(self):
        assert self.d_model % 2 == 0, "Only support even d_model."
        return self.d_model // 2

    def forward(self, x):
        """
        Args:
            x: (...)
        Returns:
            pe: (... d)
        """
        omega = self.omega

        # Give omega leading singleton dims so it broadcasts against x; x then gains
        # a trailing dim for the frequency axis.
        while omega.dim() <= x.dim():
            omega = omega.unsqueeze(0)  # (... d)

        x = x.unsqueeze(-1)  # (... 1)
        x = omega * x
        x = torch.cat([x.sin(), x.cos()], dim=-1)

        return x

    def get_pe(self, n: int):
        """
        Args:
            n: int
        Returns:
            pe: (n d)
        """
        device = self.omega.device
        return self.forward(torch.arange(n, device=device))

    def add_pe(self, x):
        """
        Args:
            x: (b t c)
        """
        e = self.get_pe(x.shape[1])  # t d
        e = e[None]  # b t d
        x = x + e
        return x
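
# Usage sketch (illustrative only; d_model=64 and the sequence length are arbitrary).
# get_pe builds the (n, d_model) sin/cos table; add_pe broadcasts it over the batch.
def _example_sinusoidal_embedding():
    emb = SinusoidalEmbedding(d_model=64)
    x = torch.randn(2, 10, 64)      # (b t c)
    pe = emb.get_pe(10)             # (t d)
    y = emb.add_pe(x)               # (b t c), x with positions added
    return pe.shape, y.shape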

class Attention(nn.Module):
    def __init__(self, d_model, n_heads, causal):
        super().__init__()
        assert d_model % n_heads == 0
        dim_head = d_model // n_heads
        self.causal = causal
        self.n_heads = n_heads
        self.scale = dim_head**-0.5
        self.to_qkv = ml.Linear(d_model, d_model * 3, bias=False)
        self.to_out = ml.Linear(d_model, d_model)

    def forward(self, x, m):
        """
        Args:
            x: (b t c)
            m: (b t 1), 1 is data, 0 is padding
        Returns:
            x: (b t c)
        """
        h = self.n_heads

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v))

        e = einsum("b i h d, b j h d -> b i j h", q, k)
        e = e * self.scale

        kpm = m.unsqueeze(1) * m.unsqueeze(2)  # b i j 1, key-padding mask

        if self.causal:
            with ml.autocast(kpm, torch.bfloat16, torch.float16) as k:
                kpm = k.squeeze(-1).tril().unsqueeze(-1)  # b i j 1

        e = e.masked_fill(kpm == 0, -torch.finfo(e.dtype).max)
        a = e.softmax(dim=2)  # Normalize on j, i.e. key

        o = einsum("b i j h, b j h d -> b i h d", a, v)
        o = o.flatten(-2)
        o = self.to_out(o)  # b t c
        o = o * m

        return o
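
# Usage sketch (illustrative only). Assumes the repo's ml.Linear / ml.autocast wrappers
# behave as drop-in replacements for their torch counterparts; all sizes are arbitrary.
def _example_attention():
    attn = Attention(d_model=64, n_heads=4, causal=True)
    x = torch.randn(2, 10, 64)      # (b t c)
    m = torch.ones(2, 10, 1)        # (b t 1)
    m[1, 7:] = 0                    # mark the tail of the second sequence as padding
    y = attn(x, m)                  # (b t c); padded positions come out zeroed
    return y.shape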

class PrenormResidual(nn.Module):
    def __init__(
        self,
        block,
        d_model,
        p_dropout,
        requires_mask=False,
        norm_type="ln",
        n_levels: int | None = None,
    ):
        super().__init__()
        self.block = block
        self.requires_mask = requires_mask
        self.norm_type = norm_type
        if norm_type == "ln":
            self.norm = nn.LayerNorm(d_model)
        elif norm_type == "adaln":
            assert n_levels is not None
            self.norm = AdaLN(d_model, n_levels)
        else:
            raise NotImplementedError(norm_type)
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x, m, l):
        """
        Args:
            x: input (b t d)
            m: mask (b t 1), 1 is valid and 0 is padding
            l: level to use, required only for AdaLN
        """
        nopts = {"l": l} if self.norm_type == "adaln" else {}
        bopts = {"m": m} if self.requires_mask else {}
        x = x + self.dropout(self.block(self.norm(x, **nopts) * m, **bopts))
        return x * m
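
# Usage sketch (illustrative only). Wraps a plain feed-forward in a pre-norm residual;
# nn.Linear stands in for the repo's ml.Linear here, and sizes are arbitrary.
def _example_prenorm_residual():
    d = 64
    ff = nn.Sequential(nn.Linear(d, d * 4), nn.GELU(), nn.Linear(d * 4, d))
    layer = PrenormResidual(ff, d_model=d, p_dropout=0.1, norm_type="ln")
    x = torch.randn(2, 10, d)       # (b t d)
    m = torch.ones(2, 10, 1)        # (b t 1), all positions valid
    y = layer(x, m, l=None)         # l is unused with plain LayerNorm
    return y.shape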

class Block(nn.Sequential):
    def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True):
        super().__init__()
        self.activation_checkpointing = activation_checkpointing
        self.attn = PrenormResidual(
            Attention(d_model, n_heads, causal),
            d_model=d_model,
            p_dropout=p_dropout,
            requires_mask=True,
            norm_type=norm_type,
            n_levels=n_levels,
        )

        n_ff = d_model * 4  # feed-forward hidden size, e.g. 1024 * 4 = 4096
        self.ffn = PrenormResidual(
            nn.Sequential(
                ml.Linear(d_model, n_ff),
                nn.GELU(),
                nn.Dropout(p_dropout),
                ml.Linear(n_ff, d_model),
            ),
            d_model=d_model,
            p_dropout=p_dropout,
            norm_type=norm_type,
            n_levels=n_levels,
        )

    def forward(self, x, m, l):
        """
        Args:
            x: (b t c)
            m: (b t 1)
            l: (b)
        """
        if x.requires_grad and self.activation_checkpointing:
            x = checkpoint(self.attn, x, m, l, use_reentrant=False)
        else:
            x = self.attn(x, m, l)
        x = self.ffn(x, m, l)
        return x
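
# Usage sketch (illustrative only; d_model=64, 4 heads and 8 levels are arbitrary).
# With norm_type="adaln", each sample's quantizer level l selects its own scale/shift.
# Activation checkpointing only kicks in when the input requires gradients.
def _example_block():
    blk = Block(d_model=64, n_heads=4, p_dropout=0.1, causal=True, norm_type="adaln", n_levels=8)
    x = torch.randn(2, 10, 64)      # (b t c)
    m = torch.ones(2, 10, 1)        # (b t 1)
    l = torch.tensor([0, 3])        # (b,)
    y = blk(x, m, l)                # (b t c)
    return y.shape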