added the ability to disable activation checkpointing through the YAML (it is very VRAM intensive at double the layer size)

mrq 2023-09-05 15:38:21 -05:00
parent 143aee7526
commit 451726fdd5
4 changed files with 31 additions and 27 deletions
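
For reference, a minimal sketch of how the new switch is assumed to be set from the YAML, based on the `Trainer.activation_checkpointing` field added below; the surrounding key names are illustrative, not taken from the repo's example configs.

# Sketch only: parse an assumed trainer section and read the new flag.
import yaml

config_yaml = """
trainer:
  activation_checkpointing: False  # disable checkpointing, trading VRAM for speed
"""

cfg = yaml.safe_load(config_yaml)
assert cfg["trainer"]["activation_checkpointing"] is False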

View File

@@ -225,6 +225,11 @@ class Model:
             return 24
         return 12
 
+    @property
+    def activation_checkpointing(self):
+        return cfg.trainer.activation_checkpointing
+
+
 @dataclass()
 class Models:
     _max_levels: int = 0
@@ -420,6 +425,8 @@ class Trainer:
     load_module_only: bool = False
     restart_step_count: bool = False
 
+    activation_checkpointing: bool = True
+
     aggressive_optimizations: bool = False
     check_for_oom: bool = True
     gc_mode: str | None = None
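
The property added to `Model` above simply forwards to the trainer-level flag, so one YAML knob governs every model. A self-contained sketch of that delegation (the dataclass layout here is assumed, not the repo's actual config module):

# Sketch of the delegation: Model.activation_checkpointing reads the global trainer flag.
from dataclasses import dataclass, field

@dataclass
class Trainer:
    activation_checkpointing: bool = True

@dataclass
class Config:
    trainer: Trainer = field(default_factory=Trainer)

cfg = Config()

@dataclass
class Model:
    @property
    def activation_checkpointing(self) -> bool:
        return cfg.trainer.activation_checkpointing

assert Model().activation_checkpointing is True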

View File

@@ -15,7 +15,8 @@ def get_model(cfg):
         d_model=cfg.dim,
         n_heads=cfg.heads,
         n_layers=cfg.layers,
-        config = cfg
+
+        config = cfg,
     )
 
     model._cfg = cfg

View File

@@ -141,10 +141,12 @@ class Base(nn.Module):
         n_heads: int = 8,
         n_layers: int = 12,
         p_dropout: float = 0.1,
 
         config = None,
     ):
         super().__init__()
         self.config = config
+        self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
+
         self.n_tokens = n_tokens
         self.d_model = d_model
@@ -180,7 +182,7 @@ class Base(nn.Module):
             decoder_ffn_embed_dim=d_model * 4,
             decoder_layers=n_layers,
             dropout=p_dropout,
-            checkpoint_activations=True,
+            checkpoint_activations=self.activation_checkpointing,
 
             chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
             recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
@@ -282,8 +284,6 @@ class Base(nn.Module):
             y: sampled tokens
         """
 
-        batch_size = len(text_list)
-
         x_list = self._samplewise_merge_tensors(
             self.text_emb(text_list),
             self.proms_emb(proms_list),
@@ -292,14 +292,15 @@ class Base(nn.Module):
         )
 
         x, m = list_to_tensor(x_list)
+        batch_size = len(text_list)
 
         device = x.device
 
         if state is not None:
             # prefill
-            prefill_size = x.shape[1]
-            # run the initial prompt to fill the KV cache
             if len(state) == 0:
+                prefill_size = x.shape[1]
+                # run the initial prompt to fill the KV cache
                 for n in range(prefill_size):
                     xi = x[:, n, :].unsqueeze(1)
                     self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
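
The hunk above mostly moves the prefill bookkeeping inside the `len(state) == 0` check, so the prompt is only pushed through the network, one position at a time, while the recurrent state is still empty. A hedged sketch of that pattern, with a hypothetical `step(xi, state)` callable standing in for the torchscale RetNet call:

# Sketch of the prefill-once pattern; `step` is a hypothetical single-step
# forward that mutates `state` (the RetNet analogue of a KV cache).
import torch

def prefill(step, x: torch.Tensor, state: dict) -> dict:
    # x: (batch, seq_len, d_model) prompt embeddings
    if len(state) == 0:                   # only on the very first call
        prefill_size = x.shape[1]
        for n in range(prefill_size):
            xi = x[:, n, :].unsqueeze(1)  # one timestep: (batch, 1, d_model)
            step(xi, state)               # fills the recurrent/incremental state
    return state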
@@ -312,7 +313,6 @@ class Base(nn.Module):
             for block in self.blocks:
                 x = block(x, m, quant_levels)
         elif self.arch_type == "retnet":
-            # to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
             x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
 
         x = self.classifier(x) * m
@@ -327,33 +327,28 @@ class Base(nn.Module):
             ignore_sep = torch.tensor(self.ignore_index, device=device)
 
-            # ignore the prompt when computing loss
-            prom_list = [
-                torch.full_like(t[..., 0], self.ignore_index) for t in proms_list
-            ]
-            # remake input with ignored input prompt
-            text_prom_list = self._samplewise_merge_tensors(
-                text_list, prom_list, sep=ignore_sep
-            )
+            # create a tensor sequence with one RVQ-bin of the input prompt, but filled with `ignore_index`, as the prompt is not needed for computing the loss against
+            prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
+            # remake input sequence
+            text_prom_list = self._samplewise_merge_tensors( text_list, prom_list, sep=ignore_sep )
 
+            # process each batch
             for i in range(len(text_prom_list)):
-                # ignore computing loss against text/prompt portion of input
+                # for the NAR, completely ignore computing the loss against the text prompt
+                # the NAR doesn't need to compute the loss for it
                 if self.resp_loss_only:
                     text_prom_list[i][:] = self.ignore_index
-                # roll the text/prompt for loss computing
+                # for the AR, shift the text/input prompt into the future by 1, and ignore the rolled-back text token
+                # the AR benefits from this, for some reason I'll figure out later
                 else:
                     text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
                     text_prom_list[i][-1] = self.ignore_index
 
-            # for the AR, roll by one and mark the ending with a stop token
-            # this coerces the model into properly inferencing causally
-            # why we don't just append a stop token in the dataloader, who knows
+            # adjust the target sequence if needed for the AR
             if shift_targ_list:
+                # create a copy, because this list is aliased against the input response sequence
                 targ_list = [*targ_list]
+                # shift the target response into the future by 1, and mark the rolled-back / last token as a stop token
+                # this prepares the AR to actually generate autoregressive sequences
                 for i in range(len(targ_list)):
                     targ_list[i] = targ_list[i].roll(-1, dims=0)
                     targ_list[i][-1] = self.stop_token
@@ -362,6 +357,7 @@ class Base(nn.Module):
             y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
 
             self.loss = dict(
+                # "nll" was in the original implementation and should actually just be called something else
                 nll=F.cross_entropy(
                     torch.cat(h_list), # input / predicted logits
                     torch.cat(y_list), # target / ground truth
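
To make the rewritten comments above concrete, here is a small worked example (not the repo's code) of how a single AR target ends up shaped: prompt positions are masked with `ignore_index`, and the response is rolled back by one with a stop token at the end, so each position is trained to predict the next token.

# Worked example of the AR target construction; token ids and sizes are made up.
import torch
import torch.nn.functional as F

ignore_index = -100
stop_token = 1024                        # assumed stop-token id, for illustration only

resp = torch.tensor([5, 6, 7, 8])        # one response sequence
targ = resp.roll(-1, dims=0)             # shift into the future by one step -> [6, 7, 8, 5]
targ[-1] = stop_token                    # the rolled-back last position must predict "stop"

prom = torch.full((3,), ignore_index)    # prompt positions contribute no loss
y = torch.cat([prom, targ])              # target over (prompt + response)

logits = torch.randn(len(y), stop_token + 1)  # fake logits over the token vocabulary
loss = F.cross_entropy(logits, y, ignore_index=ignore_index)

In the diff itself the per-sample sequences are concatenated across the batch before the single cross_entropy call, and the same ignore_index mechanism is what blanks out the prompt (and, for the NAR, the whole text/prompt portion).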

View File

@@ -153,9 +153,10 @@ class PrenormResidual(nn.Module):
 
 
 class Block(nn.Sequential):
-    def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels):
+    def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True):
         super().__init__()
 
+        self.activation_checkpointing = activation_checkpointing
         self.attn = PrenormResidual(
             Attention(d_model, n_heads, causal),
             d_model=d_model,
@@ -186,8 +187,7 @@ class Block(nn.Sequential):
             m: (b t 1)
             l: (b)
         """
-        poor_in_vram = True
-        if x.requires_grad and poor_in_vram:
+        if x.requires_grad and self.activation_checkpointing:
             x = checkpoint(self.attn, x, m, l, use_reentrant=False)
         else:
             x = self.attn(x, m, l)
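
For context, a minimal sketch of the gating pattern that this change makes configurable, using the standard torch.utils.checkpoint API; the feed-forward module and sizes are placeholders, not the repo's Attention/PrenormResidual blocks.

# Sketch: checkpoint the sub-module only when training and when enabled via config.
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(torch.nn.Module):
    def __init__(self, d_model: int = 512, activation_checkpointing: bool = True):
        super().__init__()
        self.activation_checkpointing = activation_checkpointing
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d_model, d_model * 4),
            torch.nn.GELU(),
            torch.nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.requires_grad and self.activation_checkpointing:
            # recompute self.ff's activations during backward instead of storing them
            return checkpoint(self.ff, x, use_reentrant=False)
        return self.ff(x)

Checkpointing re-runs the wrapped module's forward pass during backward, normally trading extra compute for lower activation memory; this commit exposes a YAML switch so that behaviour can be toggled per run.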