diff --git a/codes/models/lucidrains/vq.py b/codes/models/lucidrains/vq.py new file mode 100644 index 00000000..4b1019ec --- /dev/null +++ b/codes/models/lucidrains/vq.py @@ -0,0 +1,444 @@ +import functools + +import torch +from torch import nn, einsum +import torch.nn.functional as F +import torch.distributed as distributed +from torch.cuda.amp import autocast + +from einops import rearrange, repeat +from contextlib import contextmanager + + +def par(t, nm): + print(f'grad report {nm}: {t}') + return t + +def reg(t, nm): + l = torch.tensor([0], requires_grad=True, device=t.device, dtype=torch.float) + l.register_hook(functools.partial(par, nm=nm)) + t = t + l + return t + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def noop(*args, **kwargs): + pass + +def l2norm(t): + return F.normalize(t, p = 2, dim = -1) + +def log(t, eps = 1e-20): + return torch.log(t.clamp(min = eps)) + +def gumbel_noise(t): + noise = torch.zeros_like(t).uniform_(0, 1) + return -log(-log(noise)) + +def gumbel_sample(t, temperature = 1., dim = -1): + if temperature == 0: + return t.argmax(dim = dim) + + return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim) + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) + +def laplace_smoothing(x, n_categories, eps = 1e-5): + return (x + eps) / (x.sum() + n_categories * eps) + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device = device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device = device) + + return samples[indices] + +def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): + dim, dtype, device = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs ** 2).sum(dim = -1) + + buckets = dists.max(dim = -1).indices + bins = torch.bincount(buckets, minlength = num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + +# regularization losses + +def orthgonal_loss_fn(t): + # eq (2) from https://arxiv.org/abs/2112.00384 + n = t.shape[0] + normed_codes = l2norm(t) + identity = torch.eye(n, device = t.device) + cosine_sim = einsum('i d, j d -> i j', normed_codes, normed_codes) + return ((cosine_sim - identity) ** 2).sum() / (n ** 2) + +# distance types + +class EuclideanCodebook(nn.Module): + def __init__( + self, + dim, + codebook_size, + kmeans_init = False, + kmeans_iters = 10, + decay = 0.8, + eps = 1e-5, + threshold_ema_dead_code = 2, + use_ddp = False, + learnable_codebook = False, + sample_codebook_temp = 0 + ): + super().__init__() + self.decay = decay + init_fn = torch.randn if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + self.kmeans_iters = kmeans_iters + self.eps = eps + self.threshold_ema_dead_code = threshold_ema_dead_code + self.sample_codebook_temp = sample_codebook_temp + + self.all_reduce_fn = distributed.all_reduce if use_ddp else noop + + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.register_buffer('cluster_size', torch.zeros(codebook_size)) + self.register_buffer('embed_avg', embed.clone()) + + self.learnable_codebook = learnable_codebook + if learnable_codebook: + self.embed = nn.Parameter(embed) + else: + self.register_buffer('embed', embed) + + @torch.jit.ignore + def init_embed_(self, data): + if self.initted: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + def replace(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], + sample_vectors(samples, self.codebook_size), + self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + batch_samples = rearrange(batch_samples, '... d -> (...) d') + self.replace(batch_samples, mask = expired_codes) + + @autocast(enabled = False) + def forward(self, x, used_codes=[]): + shape, dtype = x.shape, x.dtype + flatten = rearrange(x, '... d -> (...) d') + + self.init_embed_(flatten) + + embed = self.embed if not self.learnable_codebook else self.embed.detach() + embed = embed.t() + + dist = -( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + + for uc in used_codes: + mask = torch.arange(0, self.codebook_size, device=x.device).unsqueeze(0).repeat(x.shape[0],1) == uc.unsqueeze(1) + dist[mask] = -torch.inf + embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = embed_ind.view(*shape[:-1]) + quantize = F.embedding(embed_ind, self.embed) + + # Perform the gumbel trick on the end result (during training) + if self.training: + quantize = flatten + (quantize - flatten).detach() + + if self.training: + cluster_size = embed_onehot.sum(0) + self.all_reduce_fn(cluster_size) + + ema_inplace(self.cluster_size, cluster_size, self.decay) + + embed_sum = flatten.t() @ embed_onehot + self.all_reduce_fn(embed_sum) + + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum() + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + self.expire_codes_(x) + + return quantize, embed_ind + +class CosineSimCodebook(nn.Module): + def __init__( + self, + dim, + codebook_size, + kmeans_init = False, + kmeans_iters = 10, + decay = 0.8, + eps = 1e-5, + threshold_ema_dead_code = 2, + use_ddp = False, + learnable_codebook = False, + sample_codebook_temp = 0. + ): + super().__init__() + self.decay = decay + + if not kmeans_init: + embed = l2norm(torch.randn(codebook_size, dim)) + else: + embed = torch.zeros(codebook_size, dim) + + self.codebook_size = codebook_size + self.kmeans_iters = kmeans_iters + self.eps = eps + self.threshold_ema_dead_code = threshold_ema_dead_code + self.sample_codebook_temp = sample_codebook_temp + + self.all_reduce_fn = distributed.all_reduce if use_ddp else noop + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.register_buffer('cluster_size', torch.zeros(codebook_size)) + + self.learnable_codebook = learnable_codebook + if learnable_codebook: + self.embed = nn.Parameter(embed) + else: + self.register_buffer('embed', embed) + + @torch.jit.ignore + def init_embed_(self, data): + if self.initted: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters, + use_cosine_sim = True) + self.embed.data.copy_(embed) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + def replace(self, samples, mask): + samples = l2norm(samples) + modified_codebook = torch.where( + mask[..., None], + sample_vectors(samples, self.codebook_size), + self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + batch_samples = rearrange(batch_samples, '... d -> (...) d') + self.replace(batch_samples, mask = expired_codes) + + @autocast(enabled = False) + def forward(self, x, used_codes=[]): + shape, dtype = x.shape, x.dtype + flatten = rearrange(x, '... d -> (...) d') + flatten = l2norm(flatten) + + self.init_embed_(flatten) + embed = self.embed if not self.learnable_codebook else self.embed.detach() + embed = l2norm(embed) + + dist = flatten @ embed.t() + for uc in used_codes: + mask = torch.arange(0, self.codebook_size, device=x.device).unsqueeze(0).repeat(x.shape[0],1) == uc.unsqueeze(1) + dist[mask] = -torch.inf + embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = embed_ind.view(*shape[:-1]) + + quantize = F.embedding(embed_ind, self.embed) + # Perform the gumbel trick on the end result (during training) + if self.training: + quantize = flatten + (quantize - flatten).detach() + + if self.training: + bins = embed_onehot.sum(0) + self.all_reduce_fn(bins) + + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) + + embed_sum = flatten.t() @ embed_onehot + self.all_reduce_fn(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = l2norm(embed_normalized) + embed_normalized = torch.where(zero_mask[..., None], embed, + embed_normalized) + ema_inplace(self.embed, embed_normalized, self.decay) + self.expire_codes_(x) + + return quantize, embed_ind + +# main class + +class VectorQuantize(nn.Module): + def __init__( + self, + dim, + codebook_size, + n_embed = None, + codebook_dim = None, + decay = 0.8, + eps = 1e-5, + kmeans_init = False, + kmeans_iters = 10, + use_cosine_sim = False, + threshold_ema_dead_code = 0, + channel_last = True, + accept_image_fmap = False, + commitment_weight = None, + commitment = 1., # deprecate in next version, turn off by default + orthogonal_reg_weight = 0., + orthogonal_reg_active_codes_only = False, + orthogonal_reg_max_codes = None, + sample_codebook_temp = 0., + sync_codebook = False + ): + super().__init__() + n_embed = default(n_embed, codebook_size) + + codebook_dim = default(codebook_dim, dim) + requires_projection = codebook_dim != dim + self.project_in = nn.Linear(dim, codebook_dim) if requires_projection \ + else nn.Identity() + self.project_out = nn.Linear(codebook_dim, dim) if requires_projection \ + else nn.Identity() + + self.eps = eps + self.commitment_weight = default(commitment_weight, commitment) + + has_codebook_orthogonal_loss = orthogonal_reg_weight > 0 + self.orthogonal_reg_weight = orthogonal_reg_weight + self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only + self.orthogonal_reg_max_codes = orthogonal_reg_max_codes + + codebook_class = EuclideanCodebook if not use_cosine_sim \ + else CosineSimCodebook + + self._codebook = codebook_class( + dim = codebook_dim, + codebook_size = n_embed, + kmeans_init = kmeans_init, + kmeans_iters = kmeans_iters, + decay = decay, + eps = eps, + threshold_ema_dead_code = threshold_ema_dead_code, + use_ddp = sync_codebook, + learnable_codebook = has_codebook_orthogonal_loss, + sample_codebook_temp = sample_codebook_temp + ) + + self.codebook_size = codebook_size + + self.accept_image_fmap = accept_image_fmap + self.channel_last = channel_last + + @property + def codebook(self): + return self._codebook.embed + + def forward(self, x, used_codes=None): + shape, device, codebook_size = x.shape, x.device, self.codebook_size + + need_transpose = not self.channel_last and not self.accept_image_fmap + + if self.accept_image_fmap: + height, width = x.shape[-2:] + x = rearrange(x, 'b c h w -> b (h w) c') + + if need_transpose: + x = rearrange(x, 'b d n -> b n d') + + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x, used_codes) + + loss = torch.tensor([0.], device = device, requires_grad = self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + if self.orthogonal_reg_weight > 0: + codebook = self.codebook + + if self.orthogonal_reg_active_codes_only: + # only calculate orthogonal loss for the activated codes for this batch + unique_code_ids = torch.unique(embed_ind) + codebook = codebook[unique_code_ids] + + num_codes = codebook.shape[0] + if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: + rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes] + codebook = codebook[rand_ids] + + orthogonal_reg_loss = orthgonal_loss_fn(codebook) + loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight + + quantize = self.project_out(quantize) + + if need_transpose: + quantize = rearrange(quantize, 'b n d -> b d n') + + if self.accept_image_fmap: + quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width) + embed_ind = rearrange(embed_ind, 'b (h w) -> b h w', h = height, w = width) + + return quantize, embed_ind, loss