fix2
This commit is contained in:
parent
aac92b01b3
commit
f9ebcf11d8
|
@ -1,3 +1,5 @@
|
||||||
|
import itertools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -224,8 +226,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
||||||
|
|
||||||
def get_grad_norm_parameter_groups(self):
|
def get_grad_norm_parameter_groups(self):
|
||||||
groups = {
|
groups = {
|
||||||
'attention_layers': [lyr.attn.parameters() for lyr in self.diff.layers],
|
'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])),
|
||||||
'ff_layers': [lyr.ff.parameters() for lyr in self.diff.layers],
|
'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])),
|
||||||
'quantizer_encoder': list(self.quantizer.encoder.parameters()),
|
'quantizer_encoder': list(self.quantizer.encoder.parameters()),
|
||||||
'quant_codebook': [self.quantizer.quantizer.codevectors],
|
'quant_codebook': [self.quantizer.quantizer.codevectors],
|
||||||
'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
|
'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user