added ability to disable activation checkpointing through the YAML (it is very VRAM intensive at double layer size)
This commit is contained in:
parent
143aee7526
commit
451726fdd5
|
@ -225,6 +225,11 @@ class Model:
|
|||
return 24
|
||||
return 12
|
||||
|
||||
@property
|
||||
def activation_checkpointing(self):
|
||||
return cfg.trainer.activation_checkpointing
|
||||
|
||||
|
||||
@dataclass()
|
||||
class Models:
|
||||
_max_levels: int = 0
|
||||
|
@ -420,6 +425,8 @@ class Trainer:
|
|||
load_module_only: bool = False
|
||||
restart_step_count: bool = False
|
||||
|
||||
activation_checkpointing: bool = True
|
||||
|
||||
aggressive_optimizations: bool = False
|
||||
check_for_oom: bool = True
|
||||
gc_mode: str | None = None
|
||||
|
|
|
@ -15,7 +15,8 @@ def get_model(cfg):
|
|||
d_model=cfg.dim,
|
||||
n_heads=cfg.heads,
|
||||
n_layers=cfg.layers,
|
||||
config = cfg
|
||||
|
||||
config = cfg,
|
||||
)
|
||||
model._cfg = cfg
|
||||
|
||||
|
|
|
@ -141,10 +141,12 @@ class Base(nn.Module):
|
|||
n_heads: int = 8,
|
||||
n_layers: int = 12,
|
||||
p_dropout: float = 0.1,
|
||||
|
||||
config = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
|
||||
|
||||
self.n_tokens = n_tokens
|
||||
self.d_model = d_model
|
||||
|
@ -180,7 +182,7 @@ class Base(nn.Module):
|
|||
decoder_ffn_embed_dim=d_model * 4,
|
||||
decoder_layers=n_layers,
|
||||
dropout=p_dropout,
|
||||
checkpoint_activations=True,
|
||||
checkpoint_activations=self.activation_checkpointing,
|
||||
|
||||
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
|
||||
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
|
||||
|
@ -282,8 +284,6 @@ class Base(nn.Module):
|
|||
y: sampled tokens
|
||||
"""
|
||||
|
||||
batch_size = len(text_list)
|
||||
|
||||
x_list = self._samplewise_merge_tensors(
|
||||
self.text_emb(text_list),
|
||||
self.proms_emb(proms_list),
|
||||
|
@ -292,14 +292,15 @@ class Base(nn.Module):
|
|||
)
|
||||
|
||||
x, m = list_to_tensor(x_list)
|
||||
|
||||
batch_size = len(text_list)
|
||||
device = x.device
|
||||
|
||||
if state is not None:
|
||||
# prefill
|
||||
prefill_size = x.shape[1]
|
||||
|
||||
# run the initial prompt to fill the KV cache
|
||||
if len(state) == 0:
|
||||
prefill_size = x.shape[1]
|
||||
# run the initial prompt to fill the KV cache
|
||||
for n in range(prefill_size):
|
||||
xi = x[:, n, :].unsqueeze(1)
|
||||
self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
|
||||
|
@ -312,7 +313,6 @@ class Base(nn.Module):
|
|||
for block in self.blocks:
|
||||
x = block(x, m, quant_levels)
|
||||
elif self.arch_type == "retnet":
|
||||
# to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
|
||||
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
|
||||
|
||||
x = self.classifier(x) * m
|
||||
|
@ -327,33 +327,28 @@ class Base(nn.Module):
|
|||
|
||||
ignore_sep = torch.tensor(self.ignore_index, device=device)
|
||||
|
||||
# ignore the prompt when computing loss
|
||||
prom_list = [
|
||||
torch.full_like(t[..., 0], self.ignore_index) for t in proms_list
|
||||
]
|
||||
# remake input with ignored input prompt
|
||||
text_prom_list = self._samplewise_merge_tensors(
|
||||
text_list, prom_list, sep=ignore_sep
|
||||
)
|
||||
# create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not neeeded for computing the loss against
|
||||
prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
|
||||
# remake input sequence
|
||||
text_prom_list = self._samplewise_merge_tensors( text_list, prom_list, sep=ignore_sep )
|
||||
|
||||
# process each batch
|
||||
for i in range(len(text_prom_list)):
|
||||
# ignore computing loss against text/prompt portion of input
|
||||
# the NAR doesn't need to compute the loss for it
|
||||
# for the NAR, ignore completely computing the loss against the text prompt
|
||||
if self.resp_loss_only:
|
||||
text_prom_list[i][:] = self.ignore_index
|
||||
|
||||
# roll the text/prompt for loss computing
|
||||
# the AR benefits from this, for some reason I'll figure out later
|
||||
# for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
|
||||
else:
|
||||
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
|
||||
text_prom_list[i][-1] = self.ignore_index
|
||||
|
||||
# for the AR, roll by one and mark the ending with a stop token
|
||||
# this coerces the model into properly inferencing causally
|
||||
|
||||
# why we don't just append a stop token in the dataloader, who knows
|
||||
# adjust the target sequence if needed for the AR
|
||||
if shift_targ_list:
|
||||
# creates a copy because this is aliased against input response sequence
|
||||
targ_list = [*targ_list]
|
||||
# shift the target response into the future by 1, and mark the rolled back token / last token as a stop token
|
||||
# this prepares the AR to actually generate autoregressive sequences
|
||||
for i in range(len(targ_list)):
|
||||
targ_list[i] = targ_list[i].roll(-1, dims=0)
|
||||
targ_list[i][-1] = self.stop_token
|
||||
|
@ -362,6 +357,7 @@ class Base(nn.Module):
|
|||
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
|
||||
|
||||
self.loss = dict(
|
||||
# "nll" was in the original implementation and should actually just be called something else
|
||||
nll=F.cross_entropy(
|
||||
torch.cat(h_list), # input / predicted logits
|
||||
torch.cat(y_list), # target / ground truth
|
||||
|
|
|
@ -153,9 +153,10 @@ class PrenormResidual(nn.Module):
|
|||
|
||||
|
||||
class Block(nn.Sequential):
|
||||
def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels):
|
||||
def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True):
|
||||
super().__init__()
|
||||
|
||||
self.activation_checkpointing = activation_checkpointing
|
||||
self.attn = PrenormResidual(
|
||||
Attention(d_model, n_heads, causal),
|
||||
d_model=d_model,
|
||||
|
@ -186,8 +187,7 @@ class Block(nn.Sequential):
|
|||
m: (b t 1)
|
||||
l: (b)
|
||||
"""
|
||||
poor_in_vram = True
|
||||
if x.requires_grad and poor_in_vram:
|
||||
if x.requires_grad and self.activation_checkpointing:
|
||||
x = checkpoint(self.attn, x, m, l, use_reentrant=False)
|
||||
else:
|
||||
x = self.attn(x, m, l)
|
||||
|
|
Loading…
Reference in New Issue
Block a user