forked from mrq/DL-Art-School
ffffpt2
This commit is contained in:
parent
95ca88efce
commit
92fe8b4dd9
|
@ -6,9 +6,8 @@ import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
# helpers
|
# helpers
|
||||||
from models.arch_util import checkpoint
|
|
||||||
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
|
from models.gpt_voice.reversible import ReversibleSequence, SequentialSequence
|
||||||
from utils.util import sequential_checkpoint
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
|
@ -210,7 +209,7 @@ class Transformer(nn.Module):
|
||||||
intermediates = []
|
intermediates = []
|
||||||
for attn, ff in self.layers.layers:
|
for attn, ff in self.layers.layers:
|
||||||
x_ff = x + checkpoint(attn, x)
|
x_ff = x + checkpoint(attn, x)
|
||||||
x = x + ff(x_ff)
|
x = x_ff + ff(x_ff)
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
intermediates.append((x_ff, x))
|
intermediates.append((x_ff, x))
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user