added ability to disable activation checkpointing through the YAML (it is very VRAM intensive at double layer size)
commit 451726fdd5
parent 143aee7526
@@ -225,6 +225,11 @@ class Model:
             return 24
         return 12
 
+    @property
+    def activation_checkpointing(self):
+        return cfg.trainer.activation_checkpointing
+
+
 @dataclass()
 class Models:
     _max_levels: int = 0
@@ -420,6 +425,8 @@ class Trainer:
     load_module_only: bool = False
     restart_step_count: bool = False
 
+    activation_checkpointing: bool = True
+
     aggressive_optimizations: bool = False
     check_for_oom: bool = True
     gc_mode: str | None = None
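For reference, this is the knob the commit message refers to. A minimal sketch of flipping it from the YAML, assuming the loader maps a top-level `trainer` section onto the `Trainer` dataclass (the loader itself is not part of this diff, and PyYAML here is only a stand-in):

    import yaml  # stand-in loader; the project wires this up through its own config machinery
    from dataclasses import dataclass

    @dataclass
    class Trainer:
        activation_checkpointing: bool = True  # new field from this commit; default keeps the old behaviour

    cfg_text = """
    trainer:
      activation_checkpointing: False  # turning it off trades the VRAM savings for faster steps
    """

    trainer = Trainer(**yaml.safe_load(cfg_text)["trainer"])
    print(trainer.activation_checkpointing)  # False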
@@ -15,7 +15,8 @@ def get_model(cfg):
         d_model=cfg.dim,
         n_heads=cfg.heads,
         n_layers=cfg.layers,
-        config = cfg
+
+        config = cfg,
     )
     model._cfg = cfg
 
@@ -141,10 +141,12 @@ class Base(nn.Module):
         n_heads: int = 8,
         n_layers: int = 12,
         p_dropout: float = 0.1,
 
         config = None,
     ):
         super().__init__()
         self.config = config
+        self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
+
         self.n_tokens = n_tokens
         self.d_model = d_model
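The ternary above keeps the old default when no config object is supplied, so callers that never pass `config` still get checkpointing enabled. A tiny standalone illustration of that fallback, with `SimpleNamespace` as a stand-in for the real config:

    from types import SimpleNamespace

    def resolve_checkpointing(config=None):
        # mirrors the fallback in Base.__init__: honor the config if given, otherwise keep the old default of True
        return config.activation_checkpointing if config is not None else True

    print(resolve_checkpointing())                                                  # True  (no config passed)
    print(resolve_checkpointing(SimpleNamespace(activation_checkpointing=False)))   # False (disabled via config)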
@@ -180,7 +182,7 @@ class Base(nn.Module):
                 decoder_ffn_embed_dim=d_model * 4,
                 decoder_layers=n_layers,
                 dropout=p_dropout,
-                checkpoint_activations=True,
+                checkpoint_activations=self.activation_checkpointing,
 
                 chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
                 recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
@@ -282,8 +284,6 @@ class Base(nn.Module):
             y: sampled tokens
         """
 
-        batch_size = len(text_list)
-
         x_list = self._samplewise_merge_tensors(
             self.text_emb(text_list),
             self.proms_emb(proms_list),
@@ -292,14 +292,15 @@ class Base(nn.Module):
         )
 
         x, m = list_to_tensor(x_list)
 
+        batch_size = len(text_list)
         device = x.device
 
         if state is not None:
             # prefill
-            prefill_size = x.shape[1]
-            # run the initial prompt to fill the KV cache
             if len(state) == 0:
+                prefill_size = x.shape[1]
+                # run the initial prompt to fill the KV cache
                 for n in range(prefill_size):
                     xi = x[:, n, :].unsqueeze(1)
                     self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x, m, quant_levels)
|
x = block(x, m, quant_levels)
|
||||||
elif self.arch_type == "retnet":
|
elif self.arch_type == "retnet":
|
||||||
# to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
|
|
||||||
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
|
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
|
||||||
|
|
||||||
x = self.classifier(x) * m
|
x = self.classifier(x) * m
|
||||||
|
@@ -327,33 +327,28 @@ class Base(nn.Module):
 
         ignore_sep = torch.tensor(self.ignore_index, device=device)
 
-        # ignore the prompt when computing loss
-        prom_list = [
-            torch.full_like(t[..., 0], self.ignore_index) for t in proms_list
-        ]
-        # remake input with ignored input prompt
-        text_prom_list = self._samplewise_merge_tensors(
-            text_list, prom_list, sep=ignore_sep
-        )
+        # create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not needed for computing the loss against
+        prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
+        # remake input sequence
+        text_prom_list = self._samplewise_merge_tensors( text_list, prom_list, sep=ignore_sep )
 
+        # process each batch
         for i in range(len(text_prom_list)):
-            # ignore computing loss against text/prompt portion of input
-            # the NAR doesn't need to compute the loss for it
+            # for the NAR, ignore completely computing the loss against the text prompt
             if self.resp_loss_only:
                 text_prom_list[i][:] = self.ignore_index
 
-            # roll the text/prompt for loss computing
-            # the AR benefits from this, for some reason I'll figure out later
+            # for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
             else:
                 text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
                 text_prom_list[i][-1] = self.ignore_index
 
-        # for the AR, roll by one and mark the ending with a stop token
-        # this coerces the model into properly inferencing causally
-
-        # why we don't just append a stop token in the dataloader, who knows
+        # adjust the target sequence if needed for the AR
         if shift_targ_list:
+            # creates a copy because this is aliased against input response sequence
             targ_list = [*targ_list]
+            # shift the target response into the future by 1, and mark the rolled back token / last token as a stop token
+            # this prepares the AR to actually generate autoregressive sequences
             for i in range(len(targ_list)):
                 targ_list[i] = targ_list[i].roll(-1, dims=0)
                 targ_list[i][-1] = self.stop_token
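To make the shifting above concrete, a toy example (token IDs are made up) of what `.roll(-1, dims=0)` plus overwriting the last position does to a target sequence:

    import torch

    stop_token = 1024      # made-up values, for illustration only
    ignore_index = -100

    # AR target: shift everything one step into the future, mark the end with a stop token
    targ = torch.tensor([7, 8, 9, 10])
    targ = targ.roll(-1, dims=0)
    targ[-1] = stop_token
    print(targ.tolist())       # [8, 9, 10, 1024]

    # text/prompt portion: same shift, but the rolled-back position is dropped from the loss instead
    text_prom = torch.tensor([3, 4, 5])
    text_prom = text_prom.roll(-1, dims=0)
    text_prom[-1] = ignore_index
    print(text_prom.tolist())  # [4, 5, -100]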
@@ -362,6 +357,7 @@ class Base(nn.Module):
         y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
 
         self.loss = dict(
+            # "nll" was in the original implementation and should actually just be called something else
             nll=F.cross_entropy(
                 torch.cat(h_list), # input / predicted logits
                 torch.cat(y_list), # target / ground truth
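On the new comment: `F.cross_entropy` takes raw logits and fuses `log_softmax` with the negative log-likelihood loss internally, which is why the `nll` key name is a slight misnomer. A quick check, with made-up shapes:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(10, 1025)            # (total tokens across the batch, vocab size) - made-up
    targets = torch.randint(0, 1025, (10,))

    a = F.cross_entropy(logits, targets)
    b = F.nll_loss(F.log_softmax(logits, dim=-1), targets)
    print(torch.allclose(a, b))  # True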
@@ -153,9 +153,10 @@ class PrenormResidual(nn.Module):
 
 
 class Block(nn.Sequential):
-    def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels):
+    def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True):
         super().__init__()
 
+        self.activation_checkpointing = activation_checkpointing
         self.attn = PrenormResidual(
             Attention(d_model, n_heads, causal),
             d_model=d_model,
@@ -186,8 +187,7 @@ class Block(nn.Sequential):
             m: (b t 1)
             l: (b)
         """
-        poor_in_vram = True
-        if x.requires_grad and poor_in_vram:
+        if x.requires_grad and self.activation_checkpointing:
             x = checkpoint(self.attn, x, m, l, use_reentrant=False)
         else:
             x = self.attn(x, m, l)
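For reference, a minimal, self-contained sketch of the conditional pattern `Block.forward` now uses, with a placeholder feed-forward module rather than the project's actual block: activation checkpointing recomputes the wrapped activations during backward instead of storing them, trading extra compute for a smaller VRAM footprint.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class TinyBlock(nn.Module):
        def __init__(self, d_model=64, activation_checkpointing=True):
            super().__init__()
            self.activation_checkpointing = activation_checkpointing
            self.ff = nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))

        def forward(self, x):
            # only checkpoint when gradients are needed and the flag is on
            if x.requires_grad and self.activation_checkpointing:
                return checkpoint(self.ff, x, use_reentrant=False)
            return self.ff(x)

    x = torch.randn(2, 16, 64, requires_grad=True)
    y = TinyBlock(activation_checkpointing=False)(x)  # runs the block normally (more VRAM, less compute)
    y = TinyBlock(activation_checkpointing=True)(x)   # wraps the call in torch.utils.checkpoint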