a
This commit is contained in:
parent
aa653115f1
commit
1dbe0b6b2e
|
@ -151,9 +151,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
def get_grad_norm_parameter_groups(self):
|
def get_grad_norm_parameter_groups(self):
|
||||||
groups = {
|
groups = {
|
||||||
'contextual_embedder': list(self.conditioning_embedder.parameters()),
|
'contextual_embedder': list(self.conditioning_embedder.parameters()),
|
||||||
'top_layers': list(self.top_layers.parameters()) + list(self.inp_block.parameters()),
|
'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()),
|
||||||
'mid_layers': list(self.mid_layers.parameters()),
|
|
||||||
'final_layers': list(self.final_layers.parameters()),
|
|
||||||
'code_converters': list(self.embeddings.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()),
|
'code_converters': list(self.embeddings.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()),
|
||||||
'time_embed': list(self.time_embed.parameters()),
|
'time_embed': list(self.time_embed.parameters()),
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user