This commit is contained in:
James Betker 2022-03-21 17:22:35 -06:00
parent cc4c9faf9a
commit e47a759ed8

View File

@ -57,11 +57,13 @@ class CheckpointedXTransformerEncoder(nn.Module):
    Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
    to channels-last that XTransformer expects.
    """
def __init__(self, needs_permute=True, **xtransformer_kwargs): def __init__(self, needs_permute=True, checkpoint=True, **xtransformer_kwargs):
super().__init__() super().__init__()
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
self.needs_permute = needs_permute self.needs_permute = needs_permute
if not checkpoint:
return
for i in range(len(self.transformer.attn_layers.layers)): for i in range(len(self.transformer.attn_layers.layers)):
n, b, r = self.transformer.attn_layers.layers[i] n, b, r = self.transformer.attn_layers.layers[i]
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])