diff --git a/vall_e/vall_e/base.py b/vall_e/vall_e/base.py
index 7566f67..da706b2 100644
--- a/vall_e/vall_e/base.py
+++ b/vall_e/vall_e/base.py
@@ -1,5 +1,4 @@
 import math
-from dataclasses import dataclass
 from functools import partial
 from typing import Literal, overload
 
@@ -9,6 +8,7 @@ from einops import rearrange
 from torch import Tensor, einsum, nn
 from torch.distributions import Categorical
 from torch.nn.utils.rnn import pad_sequence
+from torch.utils.checkpoint import checkpoint
 
 
 def _create_mask(l, device):
@@ -172,7 +172,11 @@ class Block(nn.Sequential):
             x: (b t c)
             m: (b t 1)
         """
-        x = self.attn(x, m)
+        poor_in_vram = True
+        if x.requires_grad and poor_in_vram:
+            x = checkpoint(self.attn, x, m)
+        else:
+            x = self.attn(x, m)
         x = self.ffn(x, m)
         return x
 
@@ -253,6 +257,8 @@ class Base(nn.Module):
         self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=n_prom_levels)
 
         # +1 to include the stop token
+        # Note that I don't use AdaLN for the different levels, for simplicity;
+        # using different embeddings might be enough.
         self.resp_embs = nn.ModuleList(
             [Embedding(n_resp_tokens, d_model) for _ in range(n_levels)]
         )
diff --git a/vall_e/vall_e/nar.py b/vall_e/vall_e/nar.py
index b69d8cb..fc2670d 100644
--- a/vall_e/vall_e/nar.py
+++ b/vall_e/vall_e/nar.py
@@ -1,3 +1,5 @@
+import random
+
 import torch
 from einops import rearrange
 from torch import Tensor
@@ -62,25 +64,48 @@
         else:
             assert resps_list is not None
 
-            loss = {}
-            resp_list = [o[..., 0] for o in resps_list]
-            hyp_resp_lists = [resp_list]
-            for i in range(self.n_levels):
-                resp_list = [o[..., 0] for o in resps_list]
-                next_resp_list = [o[..., i + 1] for o in resps_list]
-                hyp_resp_list = super().forward(
-                    text_list,
-                    proms_list,
-                    resp_list,
-                    next_resp_list,
-                    return_all_resp=True,
-                    shift_targ_list=False,
-                    quant_level=i,
-                )
-                hyp_resp_lists.append(hyp_resp_list)
-                loss |= {f"l{i}": self.loss}
-                del self.loss
-            self.loss = loss
+            # I noticed that VALL-E randomly samples a layer,
+            # which is more memory efficient, so let's do that.
+            # For simplicity, do it at the batch level instead of the per-sample level.
+            # Does that matter?
+
+            # Old code:
+            # loss = {}
+            # resp_list = [o[..., 0] for o in resps_list]
+            # hyp_resp_lists = [resp_list]
+            # for i in range(self.n_levels):
+            #     resp_list = [o[..., 0] for o in resps_list]
+            #     next_resp_list = [o[..., i + 1] for o in resps_list]
+            #     hyp_resp_list = super().forward(
+            #         text_list,
+            #         proms_list,
+            #         resp_list,
+            #         next_resp_list,
+            #         return_all_resp=True,
+            #         shift_targ_list=False,
+            #         quant_level=i,
+            #     )
+            #     hyp_resp_lists.append(hyp_resp_list)
+            #     loss |= {f"l{i}": self.loss}
+            #     del self.loss
+            # self.loss = loss
+
+            quant_level = random.randint(0, self.n_levels - 1)
+            cur_resp_list = [o[..., quant_level] for o in resps_list]
+            next_resp_list = [o[..., quant_level + 1] for o in resps_list]
+
+            _ = super().forward(
+                text_list,
+                proms_list,
+                cur_resp_list,
+                next_resp_list,
+                return_all_resp=True,
+                shift_targ_list=False,
+                quant_level=quant_level,
+            )
+
+            # Nothing to collect here, since we are training (no sampling needed).
+            hyp_resp_lists = []
 
         hyp_resps_list = [
             *map(lambda ts: torch.stack(ts, dim=-1), zip(*hyp_resp_lists))
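Note (not part of the patch): below is a minimal, self-contained sketch of the two techniques introduced above (gradient checkpointing of the attention sub-module, and per-batch random sampling of the quantization level). It assumes a recent PyTorch; names such as TinyBlock, d_model, and resps are hypothetical stand-ins, not this repo's API.

    import random

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint


    class TinyBlock(nn.Module):
        """Stand-in for Block: checkpoint the attention sub-module to save VRAM."""

        def __init__(self, d_model: int = 64):
            super().__init__()
            self.attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
            self.ffn = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU())

        def forward(self, x):
            def run_attn(x):
                out, _ = self.attn(x, x, x)
                return out

            # Recompute attention activations during backward instead of storing
            # them; only worthwhile when gradients are actually flowing.
            if x.requires_grad:
                x = checkpoint(run_attn, x, use_reentrant=False)
            else:
                x = run_attn(x)
            return self.ffn(x)


    x = torch.randn(2, 10, 64, requires_grad=True)
    TinyBlock()(x).sum().backward()

    # Random-quant-level training step, per batch (as in the nar.py change):
    # pick one level uniformly and train only the level -> level+1 prediction.
    n_levels = 7  # e.g. 8 codebooks -> 7 NAR refinement steps
    resps = torch.randint(0, 1024, (2, 10, n_levels + 1))  # (b, t, levels)
    quant_level = random.randint(0, n_levels - 1)
    cur, nxt = resps[..., quant_level], resps[..., quant_level + 1]
    print(quant_level, cur.shape, nxt.shape)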