This commit is contained in:
James Betker 2022-06-15 09:01:20 -06:00
parent ff5c03b460
commit b51ff8a176

View File

@ -168,7 +168,7 @@ class TransformerDiffusion(nn.Module):
for p in self.parameters(): for p in self.parameters():
p.DO_NOT_TRAIN = True p.DO_NOT_TRAIN = True
p.requires_grad = False p.requires_grad = False
for m in [self.input_converter and self.code_converter]: for m in [self.ar_input and self.ar_prior_intg]:
for p in m.parameters(): for p in m.parameters():
del p.DO_NOT_TRAIN del p.DO_NOT_TRAIN
p.requires_grad = True p.requires_grad = True