forked from mrq/DL-Art-School
m2v grad norm groups
This commit is contained in:
parent
c1bdb4f9a1
commit
c37fc3b4ed
|
@ -579,6 +579,16 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
self.min_gumbel_temperature,
|
||||
)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
if self.freeze_main_net:
|
||||
return {}
|
||||
groups = {
|
||||
'projector': list(self.m2v.input_blocks.parameters()) + list(self.m2v.projector.parameters()),
|
||||
'encoder': list(self.m2v.encoder.parameters()),
|
||||
'output_blocks': list(self.project_hid.parameters()) + list(self.project_q.parameters()),
|
||||
}
|
||||
return groups
|
||||
|
||||
def forward(self, mel):
|
||||
mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user