From 451726fdd5d6f9fc3c8c3c3990bbca8f0d83bee2 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 5 Sep 2023 15:38:21 -0500
Subject: [PATCH] added ability to disable activation checkpointing through
 the YAML (it is very VRAM intensive at double layer size)

---
 vall_e/config.py             |  7 +++++++
 vall_e/models/__init__.py    |  3 ++-
 vall_e/models/base.py        | 42 +++++++++++++++++++-----------------------
 vall_e/models/transformer.py |  6 +++---
 4 files changed, 31 insertions(+), 27 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index c3b69b0..fb441af 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -225,6 +225,11 @@ class Model:
             return 24
         return 12
 
+    @property
+    def activation_checkpointing(self):
+        return cfg.trainer.activation_checkpointing
+
+
 @dataclass()
 class Models:
     _max_levels: int = 0
@@ -420,6 +425,8 @@ class Trainer:
     load_module_only: bool = False
     restart_step_count: bool = False
 
+    activation_checkpointing: bool = True
+
     aggressive_optimizations: bool = False
     check_for_oom: bool = True
     gc_mode: str | None = None
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index 16355cc..b6983c1 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -15,7 +15,8 @@ def get_model(cfg):
         d_model=cfg.dim,
         n_heads=cfg.heads,
         n_layers=cfg.layers,
-        config = cfg
+
+        config = cfg,
     )
 
     model._cfg = cfg
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 7eefaeb..ebd7018 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -141,10 +141,12 @@ class Base(nn.Module):
         n_heads: int = 8,
         n_layers: int = 12,
         p_dropout: float = 0.1,
+
         config = None,
     ):
         super().__init__()
         self.config = config
+        self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
 
         self.n_tokens = n_tokens
         self.d_model = d_model
@@ -180,7 +182,7 @@ class Base(nn.Module):
             decoder_ffn_embed_dim=d_model * 4,
             decoder_layers=n_layers,
             dropout=p_dropout,
-            checkpoint_activations=True,
+            checkpoint_activations=self.activation_checkpointing,
 
             chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
             recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
@@ -282,8 +284,6 @@ class Base(nn.Module):
             y: sampled tokens
         """
 
-        batch_size = len(text_list)
-
         x_list = self._samplewise_merge_tensors(
             self.text_emb(text_list),
             self.proms_emb(proms_list),
@@ -292,14 +292,15 @@ class Base(nn.Module):
         )
 
         x, m = list_to_tensor(x_list)
+
+        batch_size = len(text_list)
         device = x.device
 
         if state is not None:
             # prefill
-            prefill_size = x.shape[1]
-
-            # run the initial prompt to fill the KV cache
             if len(state) == 0:
+                prefill_size = x.shape[1]
+                # run the initial prompt to fill the KV cache
                 for n in range(prefill_size):
                     xi = x[:, n, :].unsqueeze(1)
                     self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
@@ -312,7 +313,6 @@ class Base(nn.Module):
             for block in self.blocks:
                 x = block(x, m, quant_levels)
         elif self.arch_type == "retnet":
-            # to-do: actually make this work and verify it works with recurrent_forward / chunkwise_forward
             x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
 
         x = self.classifier(x) * m
@@ -327,33 +327,28 @@ class Base(nn.Module):
 
             ignore_sep = torch.tensor(self.ignore_index, device=device)
 
-            # ignore the prompt when computing loss
-            prom_list = [
-                torch.full_like(t[..., 0], self.ignore_index) for t in proms_list
-            ]
-            # remake input with ignored input prompt
-            text_prom_list = self._samplewise_merge_tensors(
-                text_list, prom_list, sep=ignore_sep
-            )
+            # create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not needed for computing the loss against
+            prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
+            # remake input sequence
+            text_prom_list = self._samplewise_merge_tensors( text_list, prom_list, sep=ignore_sep )
 
+            # process each batch
             for i in range(len(text_prom_list)):
-                # ignore computing loss against text/prompt portion of input
-                # the NAR doesn't need to compute the loss for it
+                # for the NAR, ignore completely computing the loss against the text prompt
                 if self.resp_loss_only:
                     text_prom_list[i][:] = self.ignore_index
-                # roll the text/prompt for loss computing
-                # the AR benefits from this, for some reason I'll figure out later
+                # for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
                 else:
                     text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
                     text_prom_list[i][-1] = self.ignore_index
 
-            # for the AR, roll by one and mark the ending with a stop token
-            # this coerces the model into properly inferencing causally
-
-            # why we don't just append a stop token in the dataloader, who knows
+            # adjust the target sequence if needed for the AR
             if shift_targ_list:
+                # creates a copy because this is aliased against input response sequence
                 targ_list = [*targ_list]
 
+                # shift the target response into the future by 1, and mark the rolled back token / last token as a stop token
+                # this prepares the AR to actually generate autoregressive sequences
                 for i in range(len(targ_list)):
                     targ_list[i] = targ_list[i].roll(-1, dims=0)
                     targ_list[i][-1] = self.stop_token
@@ -362,6 +357,7 @@ class Base(nn.Module):
             y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
 
             self.loss = dict(
+                # "nll" was in the original implementation and should actually just be called something else
                 nll=F.cross_entropy(
                     torch.cat(h_list), # input / predicted logits
                     torch.cat(y_list), # target / ground truth
diff --git a/vall_e/models/transformer.py b/vall_e/models/transformer.py
index e65620b..2147839 100755
--- a/vall_e/models/transformer.py
+++ b/vall_e/models/transformer.py
@@ -153,9 +153,10 @@ class PrenormResidual(nn.Module):
 
 
 class Block(nn.Sequential):
-    def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels):
+    def __init__(self, d_model, n_heads, p_dropout, causal, norm_type, n_levels, activation_checkpointing=True):
         super().__init__()
 
+        self.activation_checkpointing = activation_checkpointing
         self.attn = PrenormResidual(
             Attention(d_model, n_heads, causal),
             d_model=d_model,
@@ -186,8 +187,7 @@ class Block(nn.Sequential):
             m: (b t 1)
             l: (b)
         """
-        poor_in_vram = True
-        if x.requires_grad and poor_in_vram:
+        if x.requires_grad and self.activation_checkpointing:
             x = checkpoint(self.attn, x, m, l, use_reentrant=False)
         else:
             x = self.attn(x, m, l)
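Usage sketch (an assumption about the YAML layout rather than something shown in the patch): if the `Trainer` dataclass fields map one-to-one onto a `trainer:` block in the training YAML, the new flag could be turned off like so:

    # training YAML (hypothetical layout) - disable activation checkpointing
    trainer:
      activation_checkpointing: False

From there, `Model.activation_checkpointing` reads `cfg.trainer.activation_checkpointing`, `Base` forwards it as `checkpoint_activations=` in the RetNet decoder settings, and `Block` accepts it as the new `activation_checkpointing` constructor argument that gates the `checkpoint(...)` call in its forward pass.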