Checkpoint on attention and only sample one layer for NAR
parent 5e4ef084b8
commit d19449f1f1
@@ -1,5 +1,4 @@
 import math
-from dataclasses import dataclass
 from functools import partial
 from typing import Literal, overload
 
@@ -9,6 +8,7 @@ from einops import rearrange
 from torch import Tensor, einsum, nn
 from torch.distributions import Categorical
 from torch.nn.utils.rnn import pad_sequence
+from torch.utils.checkpoint import checkpoint
 
 
 def _create_mask(l, device):
@@ -172,6 +172,10 @@ class Block(nn.Sequential):
         x: (b t c)
         m: (b t 1)
         """
-        x = self.attn(x, m)
+        poor_in_vram = True
+        if x.requires_grad and poor_in_vram:
+            x = checkpoint(self.attn, x, m)
+        else:
+            x = self.attn(x, m)
         x = self.ffn(x, m)
         return x
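For reference, a minimal sketch of what gradient checkpointing on the attention sub-block buys: torch.utils.checkpoint.checkpoint discards the attention activations during the forward pass and recomputes them during backward, trading extra compute for lower peak VRAM. The module below is illustrative only; it uses nn.MultiheadAttention as a stand-in rather than this repo's own attention class.

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    class CheckpointedBlock(nn.Module):
        # Stand-in block: attention activations are recomputed in backward
        # instead of being stored, mirroring the `poor_in_vram` branch above.
        def __init__(self, d_model: int, n_heads: int = 8):
            super().__init__()
            self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.GELU(),
                nn.Linear(d_model * 4, d_model),
            )

        def _attn(self, x):
            out, _ = self.attn(x, x, x, need_weights=False)
            return x + out

        def forward(self, x, poor_in_vram: bool = True):
            # checkpoint() only helps when gradients will flow, hence the
            # requires_grad guard; at inference it falls through to a plain call.
            if x.requires_grad and poor_in_vram:
                x = checkpoint(self._attn, x)
            else:
                x = self._attn(x)
            return x + self.ffn(x)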
@@ -253,6 +257,8 @@ class Base(nn.Module):
         self.prom_emb = MultiEmbedding(n_tokens, d_model, n_levels=n_prom_levels)
 
         # +1 to include the stop token
+        # Note that, for the different levels, I don't use AdaLN, for simplicity;
+        # using different embeddings might be enough.
         self.resp_embs = nn.ModuleList(
             [Embedding(n_resp_tokens, d_model) for _ in range(n_levels)]
         )
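The two new comments describe the design choice: instead of conditioning the network on the quantizer level with AdaLN, each residual level gets its own embedding table. A minimal sketch of that idea follows; the class and argument names are assumptions for illustration, not the repo's MultiEmbedding/Embedding API.

    import torch
    from torch import nn

    class PerLevelEmbedding(nn.Module):
        # One embedding table per RVQ level; the level index picks the table,
        # so no level-conditioned normalization (AdaLN) is needed.
        def __init__(self, n_tokens: int, d_model: int, n_levels: int):
            super().__init__()
            self.embs = nn.ModuleList(
                nn.Embedding(n_tokens, d_model) for _ in range(n_levels)
            )

        def forward(self, codes: torch.Tensor, quant_level: int) -> torch.Tensor:
            # codes: (b, t) integer codes taken from the chosen quantizer level
            return self.embs[quant_level](codes)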
@@ -1,3 +1,5 @@
+import random
+
 import torch
 from einops import rearrange
 from torch import Tensor
@@ -62,25 +64,48 @@ class NAR(Base):
         else:
             assert resps_list is not None
 
-            loss = {}
-            resp_list = [o[..., 0] for o in resps_list]
-            hyp_resp_lists = [resp_list]
-            for i in range(self.n_levels):
-                resp_list = [o[..., 0] for o in resps_list]
-                next_resp_list = [o[..., i + 1] for o in resps_list]
-                hyp_resp_list = super().forward(
-                    text_list,
-                    proms_list,
-                    resp_list,
-                    next_resp_list,
-                    return_all_resp=True,
-                    shift_targ_list=False,
-                    quant_level=i,
-                )
-                hyp_resp_lists.append(hyp_resp_list)
-                loss |= {f"l{i}": self.loss}
-            del self.loss
-            self.loss = loss
+            # I noticed that VALL-E randomly samples a layer,
+            # which is more memory efficient, so let's do that.
+            # For simplicity, do it at the batch level instead of per sample;
+            # does that matter?
+
+            # Old code:
+            # loss = {}
+            # resp_list = [o[..., 0] for o in resps_list]
+            # hyp_resp_lists = [resp_list]
+            # for i in range(self.n_levels):
+            #     resp_list = [o[..., 0] for o in resps_list]
+            #     next_resp_list = [o[..., i + 1] for o in resps_list]
+            #     hyp_resp_list = super().forward(
+            #         text_list,
+            #         proms_list,
+            #         resp_list,
+            #         next_resp_list,
+            #         return_all_resp=True,
+            #         shift_targ_list=False,
+            #         quant_level=i,
+            #     )
+            #     hyp_resp_lists.append(hyp_resp_list)
+            #     loss |= {f"l{i}": self.loss}
+            # del self.loss
+            # self.loss = loss
+
+            quant_level = random.randint(0, self.n_levels - 1)
+            cur_resp_list = [o[..., quant_level] for o in resps_list]
+            next_resp_list = [o[..., quant_level + 1] for o in resps_list]
+
+            _ = super().forward(
+                text_list,
+                proms_list,
+                cur_resp_list,
+                next_resp_list,
+                return_all_resp=True,
+                shift_targ_list=False,
+                quant_level=quant_level,
+            )
+
+            # Yes, just nothing, as we are training.
+            hyp_resp_lists = []
 
         hyp_resps_list = [
             *map(lambda ts: torch.stack(ts, dim=-1), zip(*hyp_resp_lists))
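For reference, a condensed sketch of the training change in this hunk: one residual level is drawn at random per batch, and only that level's targets are trained, so a single forward/backward graph is kept in memory instead of n_levels of them. The function below is a simplified stand-in for the repo's super().forward call, not its actual API.

    import random

    def nar_training_step(model, text_list, proms_list, resps_list, n_levels):
        # resps_list: per-utterance code tensors whose last dim holds the
        # quantizer levels; level i is the input, level i + 1 is the target.
        quant_level = random.randint(0, n_levels - 1)
        cur_resp_list = [o[..., quant_level] for o in resps_list]
        next_resp_list = [o[..., quant_level + 1] for o in resps_list]
        # Hypothetical call; the real forward also takes return_all_resp,
        # shift_targ_list, etc.
        loss = model(
            text_list, proms_list, cur_resp_list, next_resp_list,
            quant_level=quant_level,
        )
        return loss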