.......
This commit is contained in:
parent
cc4c9faf9a
commit
e47a759ed8
|
@ -57,11 +57,13 @@ class CheckpointedXTransformerEncoder(nn.Module):
|
||||||
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
||||||
to the channels-last layout that XTransformer expects.
|
to the channels-last layout that XTransformer expects.
|
||||||
"""
|
"""
|
||||||
def __init__(self, needs_permute=True, checkpoint=True, **xtransformer_kwargs):
    """
    Build the wrapped transformer and, optionally, apply activation
    checkpointing to each of its attention-layer blocks.

    Args:
        needs_permute: stored on the instance; presumably consulted by
            forward() to permute channels-mid -> channels-last before the
            transformer runs — confirm against forward(), which is not
            visible here.
        checkpoint: when True (default), every layer triple in
            ``self.transformer.attn_layers.layers`` has its middle module
            wrapped in ``CheckpointedLayer``; when False the transformer
            is left untouched.
        **xtransformer_kwargs: forwarded verbatim to
            ``ContinuousTransformerWrapper``.
    """
    super().__init__()
    self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
    self.needs_permute = needs_permute

    if not checkpoint:
        return
    # Each entry is a 3-element ModuleList (norm / block / residual, going
    # by the original n, b, r names — verify against the x-transformers
    # layer layout). Rebuild each entry with the middle module wrapped for
    # gradient checkpointing; hoist the attribute chain out of the loop.
    layers = self.transformer.attn_layers.layers
    for idx, (norm, block, residual) in enumerate(layers):
        layers[idx] = nn.ModuleList([norm, CheckpointedLayer(block), residual])
|
Loading…
Reference in New Issue
Block a user