added the ability to disable activation checkpointing through the YAML (it is very VRAM intensive at double the layer size)

mrq 2023-09-05 15:38:21 -05:00
parent 143aee7526
commit 451726fdd5
4 changed files with 31 additions and 27 deletions
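For reference, a minimal sketch (not part of this commit) of how the new YAML toggle might be written and read; the `activation_checkpointing` field name matches the flag added below, while the loader, the `Trainer` stand-in, and the PyYAML dependency are illustrative assumptions.

# illustrative only: assumes PyYAML; the repo's real config loader is not shown here
import yaml
from dataclasses import dataclass

@dataclass
class Trainer:
    activation_checkpointing: bool = True  # new flag, defaults to on

raw = yaml.safe_load("""
trainer:
  activation_checkpointing: False
""")

trainer = Trainer(**raw["trainer"])
print(trainer.activation_checkpointing)  # False -> checkpointing disabled for this run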

View File

@@ -225,6 +225,11 @@ class Model:
return 24
return 12
@property
def activation_checkpointing(self):
return cfg.trainer.activation_checkpointing
@dataclass()
class Models:
_max_levels: int = 0
@@ -420,6 +425,8 @@ class Trainer:
load_module_only: bool = False
restart_step_count: bool = False
activation_checkpointing: bool = True
aggressive_optimizations: bool = False
check_for_oom: bool = True
gc_mode: str | None = None
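Putting the config changes in this file together with the Base changes further down, the flag flows from the trainer section of the YAML to the model; a condensed sketch of that delegation, where every name except `activation_checkpointing` is an illustrative stand-in:

from dataclasses import dataclass, field

@dataclass
class TrainerCfg:
    activation_checkpointing: bool = True

@dataclass
class Cfg:
    trainer: TrainerCfg = field(default_factory=TrainerCfg)

    @property
    def activation_checkpointing(self) -> bool:
        # the model-level property just surfaces the trainer-level setting
        return self.trainer.activation_checkpointing

class Base:
    def __init__(self, config=None):
        self.config = config
        # fall back to checkpointing enabled when no config is supplied
        self.activation_checkpointing = config.activation_checkpointing if config is not None else True

cfg = Cfg(trainer=TrainerCfg(activation_checkpointing=False))
assert Base(cfg).activation_checkpointing is False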

View File

@@ -15,7 +15,8 @@ def get_model(cfg):
d_model=cfg.dim,
n_heads=cfg.heads,
n_layers=cfg.layers,
config = cfg
config = cfg,
)
model._cfg = cfg

View File

@@ -141,10 +141,12 @@ class Base(nn.Module):
n_heads: int = 8,
n_layers: int = 12,
p_dropout: float = 0.1,
config = None,
):
super().__init__()
self.config = config
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
self.n_tokens = n_tokens
self.d_model = d_model
@@ -180,7 +182,7 @@ class Base(nn.Module):
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout,
checkpoint_activations=True,
checkpoint_activations=self.activation_checkpointing,
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
@@ -282,8 +284,6 @@ class Base(nn.Module):
y: sampled tokens
"""
batch_size = len(text_list)
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
@@ -292,14 +292,15 @@
)
x, m = list_to_tensor(x_list)
batch_size = len(text_list)
device = x.device
if state is not None:
# prefill
prefill_size = x.shape[1]
# run the initial prompt to fill the KV cache
if len(state) == 0:
prefill_size = x.shape[1]
# run the initial prompt to fill the KV cache
for n in range(prefill_size):
xi = x[:, n, :].unsqueeze(1)
self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
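The branch above exists purely to prefill: the prompt is pushed through one timestep at a time so the incremental/recurrent state is populated before any new token is decoded. A toy sketch of the pattern; `TinyRecurrentLM` and its state layout are stand-ins, not the repo's RetNet:

import torch
import torch.nn as nn

class TinyRecurrentLM(nn.Module):
    def __init__(self, d_model=16):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x_t, incremental_state):
        # fold the current step into a running hidden state, mimicking a KV cache
        h = incremental_state.get("h", torch.zeros_like(x_t))
        h = torch.tanh(self.proj(x_t) + h)
        incremental_state["h"] = h
        return h

model = TinyRecurrentLM()
x = torch.randn(1, 5, 16)   # (batch, prompt length, d_model)
state = {}

if len(state) == 0:         # prefill: run the whole prompt once to build the state
    for n in range(x.shape[1]):
        model(x[:, n, :], incremental_state=state)

y = model(torch.randn(1, 16), incremental_state=state)  # later single-step decodes reuse the state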
@@ -312,7 +313,6 @@ class Base(nn.Module):
for block in self.blocks:
x = block(x, m, quant_levels)
elif self.arch_type == "retnet":
# to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
x = self.classifier(x) * m
@@ -327,33 +327,28 @@ class Base(nn.Module):
ignore_sep = torch.tensor(self.ignore_index, device=device)
# ignore the prompt when computing loss
prom_list = [
torch.full_like(t[..., 0], self.ignore_index) for t in proms_list
]
# remake input with ignored input prompt
text_prom_list = self._samplewise_merge_tensors(
text_list, prom_list, sep=ignore_sep
)
# create a tensor sequence with one RVQ-bin of the input prompt, but filled with `ignore_index`, as the prompt is not needed when computing the loss
prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
# remake input sequence
text_prom_list = self._samplewise_merge_tensors( text_list, prom_list, sep=ignore_sep )
# process each batch
for i in range(len(text_prom_list)):
# ignore computing loss against text/prompt portion of input
# the NAR doesn't need to compute the loss for it
# for the NAR, skip computing the loss against the text prompt entirely
if self.resp_loss_only:
text_prom_list[i][:] = self.ignore_index
# roll the text/prompt for loss computing
# the AR benefits from this, for some reason I'll figure out later
# for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
else:
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
text_prom_list[i][-1] = self.ignore_index
# for the AR, roll by one and mark the ending with a stop token
# this coerces the model into properly inferencing causally
# why we don't just append a stop token in the dataloader, who knows
# adjust the target sequence if needed for the AR
if shift_targ_list:
# creates a copy because this is aliased against input response sequence
targ_list = [*targ_list]
# shift the target response into the future by 1, and mark the rolled back token / last token as a stop token
# this prepares the AR to actually generate autoregressive sequences
for i in range(len(targ_list)):
targ_list[i] = targ_list[i].roll(-1, dims=0)
targ_list[i][-1] = self.stop_token
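A small worked example (made-up values) of the shift-and-mask logic above: rolling the target left by one makes position t predict token t+1, the wrapped-around last slot is overwritten with the stop token, and `ignore_index` keeps the prompt portion out of the loss:

import torch
import torch.nn.functional as F

ignore_index, stop_token = -100, 1024
targ = torch.tensor([7, 3, 9, 2])       # response tokens

shifted = targ.roll(-1, dims=0)         # [3, 9, 2, 7]
shifted[-1] = stop_token                # [3, 9, 2, 1024] -> last position now predicts "stop"

# prompt positions contribute nothing to the loss
y = torch.cat([torch.full((3,), ignore_index), shifted])
logits = torch.randn(y.numel(), 1025)   # dummy (sequence, vocab) predictions
loss = F.cross_entropy(logits, y, ignore_index=ignore_index)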
@@ -362,6 +357,7 @@ class Base(nn.Module):
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
self.loss = dict(
# "nll" was in the original implementation and should actually just be called something else
nll=F.cross_entropy(
torch.cat(h_list), # input / predicted logits
torch.cat(y_list), # target / ground truth

View File

@@ -153,9 +153,10 @@ class PrenormResidual(nn.Module):
class Block(nn.Sequential):
def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels):
def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True):
super().__init__()
self.activation_checkpointing = activation_checkpointing
self.attn = PrenormResidual(
Attention(d_model, n_heads, causal),
d_model=d_model,
@@ -186,8 +187,7 @@ class Block(nn.Sequential):
m: (b t 1)
l: (b)
"""
poor_in_vram = True
if x.requires_grad and poor_in_vram:
if x.requires_grad and self.activation_checkpointing:
x = checkpoint(self.attn, x, m, l, use_reentrant=False)
else:
x = self.attn(x, m, l)
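For completeness, a self-contained sketch of the same gating pattern as `Block.forward` above, using `torch.utils.checkpoint`; the feed-forward module is a stand-in, but the requires_grad-plus-flag check mirrors the change:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    def __init__(self, d_model=32, activation_checkpointing=True):
        super().__init__()
        self.activation_checkpointing = activation_checkpointing
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x):
        if x.requires_grad and self.activation_checkpointing:
            # recompute activations during backward: less VRAM, more compute
            return checkpoint(self.ff, x, use_reentrant=False)
        # with the flag off (or during inference) just run a plain forward
        return self.ff(x)

x = torch.randn(2, 32, requires_grad=True)
out = CheckpointedBlock(activation_checkpointing=False)(x)  # the new flag skips checkpointing entirely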