This commit is contained in:
James Betker 2022-06-05 01:31:37 -06:00
parent aac92b01b3
commit f9ebcf11d8

View File

@ -1,3 +1,5 @@
import itertools
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -224,8 +226,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
def get_grad_norm_parameter_groups(self): def get_grad_norm_parameter_groups(self):
groups = { groups = {
'attention_layers': [lyr.attn.parameters() for lyr in self.diff.layers], 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])),
'ff_layers': [lyr.ff.parameters() for lyr in self.diff.layers], 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])),
'quantizer_encoder': list(self.quantizer.encoder.parameters()), 'quantizer_encoder': list(self.quantizer.encoder.parameters()),
'quant_codebook': [self.quantizer.quantizer.codevectors], 'quant_codebook': [self.quantizer.quantizer.codevectors],
'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()), 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),