diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 2ca71877..c8c4c077 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -579,6 +579,16 @@ class ContrastiveTrainingWrapper(nn.Module): self.min_gumbel_temperature, ) + def get_grad_norm_parameter_groups(self): + if self.freeze_main_net: + return {} + groups = { + 'projector': list(self.m2v.input_blocks.parameters()) + list(self.m2v.projector.parameters()), + 'encoder': list(self.m2v.encoder.parameters()), + 'output_blocks': list(self.project_hid.parameters()) + list(self.project_q.parameters()), + } + return groups + def forward(self, mel): mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.