""" # https://github.com/enhuiz/vall-e/ """ import math import torch import torch.nn.functional as F import traceback from typing import Literal, overload from functools import partial from einops import rearrange from torch import Tensor, einsum, nn from torch.utils.checkpoint import checkpoint from ..utils import wrapper as ml class AdaLN(nn.Module): def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2): super().__init__() self.eps = eps self.emb = nn.Embedding(n_levels, d_model * 2) self.k = k self.c = c nn.init.zeros_(self.emb.weight) def forward(self, x, l): h = F.layer_norm(x, x.shape[-1:], eps=self.eps) # The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135) # performed worse than vanilla LayerNorm. # The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN. # Did they use AdaNorm inside AdaLN? (as follows) h = self.c * (1 - (self.k * h).detach()) * h logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1) y = logγ.exp() * h + β return y class SinusoidalEmbedding(nn.Module): def __init__(self, d_model): super().__init__() self.d_model = d_model exponent = torch.arange(self.d_half, dtype=torch.float32) exponent = exponent / self.d_half omega = torch.exp(-math.log(1e4) * exponent) self.omega: torch.Tensor self.register_buffer("omega", omega, persistent=False) @property def d_half(self): assert self.d_model % 2 == 0, "Only support even d_model." return self.d_model // 2 def forward(self, x): """ Args: x: (...) Returns: pe: (... d) """ omega = self.omega while omega.dim() <= x.dim(): omega = omega.unsqueeze(0) # (... d) x = x.unsqueeze(-1) # (... 1) x = omega * x x = torch.cat([x.sin(), x.cos()], dim=-1) return x def get_pe(self, n: int): """ Args: n: int Returns: pe: (n d) """ device = self.omega.device return self.forward(torch.arange(n, device=device)) def add_pe(self, x): """ Args: x: (b t c) """ e = self.get_pe(x.shape[1]) # t d e = e[None] # b t d x = x + e return x class Attention(nn.Module): def __init__(self, d_model, n_heads, causal): super().__init__() assert d_model % n_heads == 0 dim_head = d_model // n_heads self.causal = causal self.n_heads = n_heads self.scale = dim_head**-0.5 self.to_qkv = ml.Linear(d_model, d_model * 3, bias=False) self.to_out = ml.Linear(d_model, d_model) def forward(self, x, m): """ Args: x: (b t c) m: (b t c), 1 is data, 0 is padding Returns: x: (b t c) """ h = self.n_heads q, k, v = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v)) e = einsum("b i h d, b j h d -> b i j h", q, k) e = e * self.scale kpm = m.unsqueeze(1) * m.unsqueeze(2) # b i j 1 if self.causal: with ml.autocast(kpm, torch.bfloat16, torch.float16) as k: kpm = k.squeeze(-1).tril().unsqueeze(-1) # b i j 1 e = e.masked_fill(kpm == 0, -torch.finfo(e.dtype).max) a = e.softmax(dim=2) # Normalize on j, i.e. 
key o = einsum("b i j h, b j h d -> b i h d", a, v) o = o.flatten(-2) o = self.to_out(o) # b t c o = o * m return o class PrenormResidual(nn.Module): def __init__( self, block, d_model, p_dropout, requires_mask=False, norm_type="ln", n_levels: int | None = None, ): super().__init__() self.block = block self.requires_mask = requires_mask self.norm_type = norm_type if norm_type == "ln": self.norm = nn.LayerNorm(d_model) elif norm_type == "adaln": assert n_levels is not None self.norm = AdaLN(d_model, n_levels) else: raise NotImplementedError(norm_type) self.dropout = nn.Dropout(p_dropout) def forward(self, x, m, l): """ Args: x: input (b t d) m: mask (b t 1), 1 is valuable and 0 is padding l: level to use, required only for AdaLN """ nopts = {"l": l} if self.norm_type == "adaln" else {} bopts = {"m": m} if self.requires_mask else {} x = x + self.dropout(self.block(self.norm(x, **nopts) * m, **bopts)) return x * m class Block(nn.Sequential): def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True): super().__init__() self.activation_checkpointing = activation_checkpointing self.attn = PrenormResidual( Attention(d_model, n_heads, causal), d_model=d_model, p_dropout=p_dropout, requires_mask=True, norm_type=norm_type, n_levels=n_levels, ) n_ff = d_model * 4 # 1024 * 4 = 4096 feed-forwards self.ffn = PrenormResidual( nn.Sequential( ml.Linear(d_model, n_ff), nn.GELU(), nn.Dropout(p_dropout), ml.Linear(n_ff, d_model), ), d_model=d_model, p_dropout=p_dropout, norm_type=norm_type, n_levels=n_levels, ) def forward(self, x, m, l): """ Args: x: (b t c) m: (b t 1) l: (b) """ if x.requires_grad and self.activation_checkpointing: x = checkpoint(self.attn, x, m, l, use_reentrant=False) else: x = self.attn(x, m, l) x = self.ffn(x, m, l) return x
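

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the upstream repo) showing how the
    # modules above compose. It assumes ml.Linear / ml.autocast from
    # ..utils.wrapper behave like their plain torch counterparts, and it must be
    # run in package context (e.g. `python -m <package>.<module>`) because of the
    # relative import at the top of this file.
    b, t, d_model, n_heads, n_levels = 2, 8, 64, 4, 8

    block = Block(
        d_model=d_model,
        n_heads=n_heads,
        p_dropout=0.1,
        causal=True,
        norm_type="adaln",
        n_levels=n_levels,
        activation_checkpointing=False,
    )
    pe = SinusoidalEmbedding(d_model)

    x = torch.randn(b, t, d_model)        # (b t c)
    m = torch.ones(b, t, 1)               # (b t 1), all positions are data
    l = torch.randint(0, n_levels, (b,))  # (b), one quantizer level per sample

    y = block(pe.add_pe(x), m, l)
    print(y.shape)  # torch.Size([2, 8, 64])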