m2v grad norm groups

This commit is contained in:
James Betker 2022-05-17 16:29:36 -06:00
parent c1bdb4f9a1
commit c37fc3b4ed

View File

@ -579,6 +579,16 @@ class ContrastiveTrainingWrapper(nn.Module):
self.min_gumbel_temperature,
)
def get_grad_norm_parameter_groups(self):
if self.freeze_main_net:
return {}
groups = {
'projector': list(self.m2v.input_blocks.parameters()) + list(self.m2v.projector.parameters()),
'encoder': list(self.m2v.encoder.parameters()),
'output_blocks': list(self.project_hid.parameters()) + list(self.project_q.parameters()),
}
return groups
def forward(self, mel):
mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.