diff --git a/codes/models/tacotron2/loss.py b/codes/models/tacotron2/loss.py index 3109e317..77d6457b 100644 --- a/codes/models/tacotron2/loss.py +++ b/codes/models/tacotron2/loss.py @@ -11,6 +11,8 @@ class Tacotron2Loss(ConfigurableLoss): self.mel_output_postnet_key = opt_loss['mel_output_postnet_key'] self.gate_target_key = opt_loss['gate_target_key'] self.gate_output_key = opt_loss['gate_output_key'] + self.last_mel_loss = 0 + self.last_gate_loss = 0 def forward(self, _, state): mel_target, gate_target = state[self.mel_target_key], state[self.gate_target_key] @@ -23,12 +25,22 @@ class Tacotron2Loss(ConfigurableLoss): mel_loss = nn.MSELoss()(mel_out, mel_target) + \ nn.MSELoss()(mel_out_postnet, mel_target) gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) + self.last_mel_loss = mel_loss.detach().clone().mean().item() + self.last_gate_loss = gate_loss.detach().clone().mean().item() return mel_loss + gate_loss + def extra_metrics(self): + return { + 'mel_loss': self.last_mel_loss, + 'gate_loss': self.last_gate_loss + } + class Tacotron2LossRaw(nn.Module): def __init__(self): super().__init__() + self.last_mel_loss = 0 + self.last_gate_loss = 0 def forward(self, model_output, targets): mel_target, gate_target = targets[0], targets[1] @@ -41,4 +53,12 @@ class Tacotron2LossRaw(nn.Module): mel_loss = nn.MSELoss()(mel_out, mel_target) + \ nn.MSELoss()(mel_out_postnet, mel_target) gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) + self.last_mel_loss = mel_loss.detach().clone().mean().item() + self.last_gate_loss = gate_loss.detach().clone().mean().item() return mel_loss + gate_loss + + def extra_metrics(self): + return { + 'mel_loss': self.last_mel_loss, + 'gate_loss': self.last_gate_loss + } \ No newline at end of file diff --git a/codes/trainer/networks.py b/codes/trainer/networks.py index cc1104a3..e0733ffe 100644 --- a/codes/trainer/networks.py +++ b/codes/trainer/networks.py @@ -1,5 +1,6 @@ import importlib import logging +import os import pkgutil import sys from collections import OrderedDict @@ -36,6 +37,8 @@ def find_registered_model_fns(base_path='models'): found_fns = {} module_iter = pkgutil.walk_packages([base_path]) for mod in module_iter: + if os.getcwd() not in mod.module_finder.path: + continue # I have no idea why this is necessary - I think it's a bug in the latest PyWindows release. if mod.ispkg: EXCLUSION_LIST = ['flownet2'] if mod.name not in EXCLUSION_LIST: