This commit is contained in:
James Betker 2022-05-28 22:32:38 -06:00
parent da367da411
commit 536c8558ae

View File

@ -136,7 +136,7 @@ class TransformerDiffusion(nn.Module):
groups = {
'contextual_embedder': list(self.conditioning_embedder.parameters()),
'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()),
'code_converters': list(self.embeddings.parameters()) + list(self.code_converter.parameters()),
'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()),
'time_embed': list(self.time_embed.parameters()),
}
return groups