This commit is contained in:
James Betker 2022-06-08 11:54:46 -06:00
parent dee2b72786
commit 91be38cba3

View File

@ -27,7 +27,7 @@ class ConditioningEncoder(nn.Module):
def forward(self, x):
h = checkpoint(self.init, x)
h = self.attn(h
h = self.attn(h)
return h.mean(dim=2)