forked from mrq/DL-Art-School
m2v grad norm groups
This commit is contained in:
parent
c1bdb4f9a1
commit
c37fc3b4ed
|
@ -579,6 +579,16 @@ class ContrastiveTrainingWrapper(nn.Module):
|
||||||
self.min_gumbel_temperature,
|
self.min_gumbel_temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_grad_norm_parameter_groups(self):
|
||||||
|
if self.freeze_main_net:
|
||||||
|
return {}
|
||||||
|
groups = {
|
||||||
|
'projector': list(self.m2v.input_blocks.parameters()) + list(self.m2v.projector.parameters()),
|
||||||
|
'encoder': list(self.m2v.encoder.parameters()),
|
||||||
|
'output_blocks': list(self.project_hid.parameters()) + list(self.project_q.parameters()),
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
|
||||||
def forward(self, mel):
|
def forward(self, mel):
|
||||||
mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.
|
mel = mel[:, :, :-1] # The MEL computation always pads with 1, throwing off optimal tensor math.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user