.......
This commit is contained in:
parent
cc4c9faf9a
commit
e47a759ed8
|
@ -57,11 +57,13 @@ class CheckpointedXTransformerEncoder(nn.Module):
|
|||
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
||||
to channels-last that XTransformer expects.
|
||||
"""
|
||||
def __init__(self, needs_permute=True, **xtransformer_kwargs):
|
||||
def __init__(self, needs_permute=True, checkpoint=True, **xtransformer_kwargs):
|
||||
super().__init__()
|
||||
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
|
||||
self.needs_permute = needs_permute
|
||||
|
||||
if not checkpoint:
|
||||
return
|
||||
for i in range(len(self.transformer.attn_layers.layers)):
|
||||
n, b, r = self.transformer.attn_layers.layers[i]
|
||||
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
||||
|
|
Loading…
Reference in New Issue
Block a user