Checkpoint on attention and only sample one layer for NAR

parent 5e4ef084b8
commit d19449f1f1
@@ -1,5 +1,4 @@
 import math
 from dataclasses import dataclass
 from functools import partial
 from typing import Literal, overload

@@ -9,6 +8,7 @@ from einops import rearrange
 from torch import Tensor, einsum, nn
 from torch.distributions import Categorical
 from torch.nn.utils.rnn import pad_sequence
+from torch.utils.checkpoint import checkpoint


 def _create_mask(l, device):
@@ -172,7 +172,11 @@ class Block(nn.Sequential):
         x: (b t c)
         m: (b t 1)
         """
-        x = self.attn(x, m)
+        poor_in_vram = True
+        if x.requires_grad and poor_in_vram:
+            x = checkpoint(self.attn, x, m)
+        else:
+            x = self.attn(x, m)
         x = self.ffn(x, m)
         return x

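For context on what the guarded call above buys: torch.utils.checkpoint.checkpoint drops the wrapped forward's intermediate activations and recomputes them during the backward pass, trading extra compute for a large memory saving on the attention sublayer, and the x.requires_grad guard skips it at inference time when there is nothing to save. Below is a minimal, self-contained sketch of the same idea; the module, shapes, and the use of nn.MultiheadAttention are illustrative stand-ins, not the repo's Block.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedSelfAttention(nn.Module):
    """Toy sublayer mirroring the requires_grad guard added in this commit."""

    def __init__(self, d_model: int, n_heads: int, poor_in_vram: bool = True):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.poor_in_vram = poor_in_vram

    def _attn(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Checkpoint only when a backward pass will actually happen;
        # under torch.no_grad() it would just be wasted recomputation.
        if x.requires_grad and self.poor_in_vram:
            # use_reentrant=False is the non-reentrant mode recommended on recent PyTorch
            return checkpoint(self._attn, x, use_reentrant=False)
        return self._attn(x)


x = torch.randn(2, 16, 64, requires_grad=True)  # (b t c)
y = CheckpointedSelfAttention(d_model=64, n_heads=4)(x)
y.sum().backward()  # the attention activations are recomputed here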
@@ -253,6 +257,8 @@ class Base(nn.Module):
         self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=n_prom_levels)

         # +1 to include the stop token
+        # Note that, for different levels, I don't use AdaLN, for simplicity;
+        # using different embeddings might be enough.
         self.resp_embs = nn.ModuleList(
             [Embedding(n_resp_tokens, d_model) for _ in range(n_levels)]
         )
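The new comment records a design choice: rather than conditioning one shared embedding on the quantizer level through AdaLN, each level gets its own response-embedding table and the right table is picked at lookup time. A rough sketch of that lookup is below; the sizes and the plain nn.Embedding are placeholders, not the repo's actual n_resp_tokens, d_model, n_levels, or Embedding class.

import torch
from torch import nn

n_resp_tokens, d_model, n_levels = 1025, 512, 7  # hypothetical sizes

# One embedding table per quantizer level, analogous to self.resp_embs above.
resp_embs = nn.ModuleList(
    [nn.Embedding(n_resp_tokens, d_model) for _ in range(n_levels)]
)

quant_level = 3                              # level currently being predicted
codes = torch.randint(0, 1024, (2, 50))      # (batch, time) codes at that level
x = resp_embs[quant_level](codes)            # (batch, time, d_model)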
@@ -1,3 +1,5 @@
+import random
+
 import torch
 from einops import rearrange
 from torch import Tensor
@@ -62,25 +64,48 @@ class NAR(Base):
         else:
             assert resps_list is not None

-            loss = {}
-            resp_list = [o[..., 0] for o in resps_list]
-            hyp_resp_lists = [resp_list]
-            for i in range(self.n_levels):
-                resp_list = [o[..., 0] for o in resps_list]
-                next_resp_list = [o[..., i + 1] for o in resps_list]
-                hyp_resp_list = super().forward(
-                    text_list,
-                    proms_list,
-                    resp_list,
-                    next_resp_list,
-                    return_all_resp=True,
-                    shift_targ_list=False,
-                    quant_level=i,
-                )
-                hyp_resp_lists.append(hyp_resp_list)
-                loss |= {f"l{i}": self.loss}
-                del self.loss
-            self.loss = loss
+            # I noticed that VALL-E randomly samples a layer,
+            # which is more memory efficient, so let's do that.
+            # For simplicity, sample at the batch level instead of per sample.
+            # Does that matter?
+
+            # Old code:
+            # loss = {}
+            # resp_list = [o[..., 0] for o in resps_list]
+            # hyp_resp_lists = [resp_list]
+            # for i in range(self.n_levels):
+            #     resp_list = [o[..., 0] for o in resps_list]
+            #     next_resp_list = [o[..., i + 1] for o in resps_list]
+            #     hyp_resp_list = super().forward(
+            #         text_list,
+            #         proms_list,
+            #         resp_list,
+            #         next_resp_list,
+            #         return_all_resp=True,
+            #         shift_targ_list=False,
+            #         quant_level=i,
+            #     )
+            #     hyp_resp_lists.append(hyp_resp_list)
+            #     loss |= {f"l{i}": self.loss}
+            #     del self.loss
+            # self.loss = loss
+
+            quant_level = random.randint(0, self.n_levels - 1)
+            cur_resp_list = [o[..., quant_level] for o in resps_list]
+            next_resp_list = [o[..., quant_level + 1] for o in resps_list]
+
+            _ = super().forward(
+                text_list,
+                proms_list,
+                cur_resp_list,
+                next_resp_list,
+                return_all_resp=True,
+                shift_targ_list=False,
+                quant_level=quant_level,
+            )
+
+            # Intentionally empty: nothing to collect while training.
+            hyp_resp_lists = []

         hyp_resps_list = [
             *map(lambda ts: torch.stack(ts, dim=-1), zip(*hyp_resp_lists))
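Stepping back from the diff, the new training path amounts to: draw one quantizer level uniformly per batch, feed the codes at that level as input, and compute the loss only against the codes at the next level, so each step touches a single level instead of looping over all of them. A standalone sketch of that sampling step is below; the per-utterance shape of (timesteps, n_levels + 1) and the helper name are assumptions for illustration, not the repo's API.

import random

import torch

n_levels = 7  # hypothetical number of NAR levels


def sample_nar_level(resps_list: list[torch.Tensor]) -> tuple[int, list[torch.Tensor], list[torch.Tensor]]:
    """Pick one quantizer level for the whole batch; return (level, inputs, targets)."""
    quant_level = random.randint(0, n_levels - 1)
    cur = [r[..., quant_level] for r in resps_list]      # codes the model conditions on
    nxt = [r[..., quant_level + 1] for r in resps_list]  # codes it must predict
    return quant_level, cur, nxt


# Toy batch of two utterances with different lengths, each (t, n_levels + 1).
resps_list = [
    torch.randint(0, 1024, (40, n_levels + 1)),
    torch.randint(0, 1024, (55, n_levels + 1)),
]
level, cur_resp_list, next_resp_list = sample_nar_level(resps_list)

Because the level is drawn once per batch, every utterance in that batch shares it, which keeps the batched shapes uniform; the trailing question in the diff's comment is whether drawing a level per sample would make a practical difference.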