Checkpoint on attention and only sample one layer for NAR

enhuiz 2023-01-12 14:41:30 +08:00
parent 5e4ef084b8
commit d19449f1f1
2 changed files with 52 additions and 21 deletions

View File

@@ -1,5 +1,4 @@
import math
from dataclasses import dataclass
from functools import partial
from typing import Literal, overload
@@ -9,6 +8,7 @@ from einops import rearrange
from torch import Tensor, einsum, nn
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
def _create_mask(l, device):
@@ -172,7 +172,11 @@ class Block(nn.Sequential):
x: (b t c)
m: (b t 1)
"""
x = self.attn(x, m)
poor_in_vram = True
if x.requires_grad and poor_in_vram:
x = checkpoint(self.attn, x, m)
else:
x = self.attn(x, m)
x = self.ffn(x, m)
return x
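
A minimal, self-contained sketch of the checkpointing pattern above (the module, shapes, and names here are illustrative stand-ins, not the repo's Block): torch.utils.checkpoint.checkpoint drops the activations of the wrapped call during the forward pass and recomputes them during backward, trading extra compute for lower peak VRAM. The x.requires_grad guard matters because checkpointing is only useful when gradients are being tracked.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    # Illustrative stand-in for an attention + feed-forward block.
    def __init__(self, d_model: int = 512, n_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def _attn_fn(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

    def forward(self, x: torch.Tensor, poor_in_vram: bool = True) -> torch.Tensor:
        if x.requires_grad and poor_in_vram:
            # Do not store attention activations; recompute them on backward.
            x = x + checkpoint(self._attn_fn, x)
        else:
            x = x + self._attn_fn(x)
        return x + self.ffn(x)

# Usage: the memory saving only applies while gradients are tracked.
# y = CheckpointedBlock()(torch.randn(2, 100, 512, requires_grad=True))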
@@ -253,6 +257,8 @@ class Base(nn.Module):
self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=n_prom_levels)
# +1 to include the stop token
# Note that, for the different levels, I don't use AdaLN, for simplicity;
# using different embeddings might be enough.
self.resp_embs = nn.ModuleList(
[Embedding(n_resp_tokens, d_model) for _ in range(n_levels)]
)
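
A minimal sketch of the "different embeddings per level" idea above (n_resp_tokens, d_model, and n_levels follow the diff; the class and forward signature are illustrative): each quantizer level owns an independent nn.Embedding table and the level index simply selects it, instead of sharing one table and conditioning on the level with AdaLN.

from torch import nn, Tensor

class PerLevelRespEmbedding(nn.Module):
    # Illustrative: one independent token-embedding table per quantizer level.
    def __init__(self, n_resp_tokens: int, d_model: int, n_levels: int):
        super().__init__()
        self.resp_embs = nn.ModuleList(
            [nn.Embedding(n_resp_tokens, d_model) for _ in range(n_levels)]
        )

    def forward(self, resp: Tensor, quant_level: int) -> Tensor:
        # resp: (b t) token ids for a single level; the level picks the table.
        return self.resp_embs[quant_level](resp)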

View File

@@ -1,3 +1,5 @@
import random
import torch
from einops import rearrange
from torch import Tensor
@@ -62,25 +64,48 @@ class NAR(Base):
else:
assert resps_list is not None
loss = {}
resp_list = [o[..., 0] for o in resps_list]
hyp_resp_lists = [resp_list]
for i in range(self.n_levels):
resp_list = [o[..., 0] for o in resps_list]
next_resp_list = [o[..., i + 1] for o in resps_list]
hyp_resp_list = super().forward(
text_list,
proms_list,
resp_list,
next_resp_list,
return_all_resp=True,
shift_targ_list=False,
quant_level=i,
)
hyp_resp_lists.append(hyp_resp_list)
loss |= {f"l{i}": self.loss}
del self.loss
self.loss = loss
# I noticed that VALL-E randomly samples a layer,
# which is more memory efficient, so let's do it.
# For simplicity, do it at the batch level instead of the per-sample level.
# Does that matter?
# Old code:
# loss = {}
# resp_list = [o[..., 0] for o in resps_list]
# hyp_resp_lists = [resp_list]
# for i in range(self.n_levels):
# resp_list = [o[..., 0] for o in resps_list]
# next_resp_list = [o[..., i + 1] for o in resps_list]
# hyp_resp_list = super().forward(
# text_list,
# proms_list,
# resp_list,
# next_resp_list,
# return_all_resp=True,
# shift_targ_list=False,
# quant_level=i,
# )
# hyp_resp_lists.append(hyp_resp_list)
# loss |= {f"l{i}": self.loss}
# del self.loss
# self.loss = loss
quant_level = random.randint(0, self.n_levels - 1)
cur_resp_list = [o[..., quant_level] for o in resps_list]
next_resp_list = [o[..., quant_level + 1] for o in resps_list]
_ = super().forward(
text_list,
proms_list,
cur_resp_list,
next_resp_list,
return_all_resp=True,
shift_targ_list=False,
quant_level=quant_level,
)
# Yes, just leave it empty, as we are training and only the loss is needed
hyp_resp_lists = []
hyp_resps_list = [
*map(lambda ts: torch.stack(ts, dim=-1), zip(*hyp_resp_lists))
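
A minimal sketch of the sampling change above (resps here is a single (t, n_levels + 1) code matrix for one utterance; the function name and shapes are illustrative, not the repo's API): instead of looping over all n_levels and running n_levels forward passes per training step, one quantizer level is drawn per batch and the model is trained to predict the codes of level q + 1 from the codes of level q, which lowers per-step memory and compute at the cost of covering the levels more noisily over time.

import random
from torch import Tensor

def sample_nar_level(resps: Tensor, n_levels: int) -> tuple[int, Tensor, Tensor]:
    # resps: (t, n_levels + 1) residual-quantizer codes for one utterance.
    # Draw one level for the whole step instead of iterating over all levels.
    q = random.randint(0, n_levels - 1)
    cur = resps[..., q]      # input codes at the sampled level
    nxt = resps[..., q + 1]  # target codes at the next level
    return q, cur, nxt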