tacotron2 work

This commit is contained in:
James Betker 2021-07-14 21:41:57 -06:00
parent fe0c699ced
commit 5584cfcc7a
2 changed files with 23 additions and 0 deletions

View File

@ -11,6 +11,8 @@ class Tacotron2Loss(ConfigurableLoss):
self.mel_output_postnet_key = opt_loss['mel_output_postnet_key'] self.mel_output_postnet_key = opt_loss['mel_output_postnet_key']
self.gate_target_key = opt_loss['gate_target_key'] self.gate_target_key = opt_loss['gate_target_key']
self.gate_output_key = opt_loss['gate_output_key'] self.gate_output_key = opt_loss['gate_output_key']
self.last_mel_loss = 0
self.last_gate_loss = 0
def forward(self, _, state): def forward(self, _, state):
mel_target, gate_target = state[self.mel_target_key], state[self.gate_target_key] mel_target, gate_target = state[self.mel_target_key], state[self.gate_target_key]
@ -23,12 +25,22 @@ class Tacotron2Loss(ConfigurableLoss):
mel_loss = nn.MSELoss()(mel_out, mel_target) + \ mel_loss = nn.MSELoss()(mel_out, mel_target) + \
nn.MSELoss()(mel_out_postnet, mel_target) nn.MSELoss()(mel_out_postnet, mel_target)
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
self.last_mel_loss = mel_loss.detach().clone().mean().item()
self.last_gate_loss = gate_loss.detach().clone().mean().item()
return mel_loss + gate_loss return mel_loss + gate_loss
def extra_metrics(self):
return {
'mel_loss': self.last_mel_loss,
'gate_loss': self.last_gate_loss
}
class Tacotron2LossRaw(nn.Module): class Tacotron2LossRaw(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.last_mel_loss = 0
self.last_gate_loss = 0
def forward(self, model_output, targets): def forward(self, model_output, targets):
mel_target, gate_target = targets[0], targets[1] mel_target, gate_target = targets[0], targets[1]
@ -41,4 +53,12 @@ class Tacotron2LossRaw(nn.Module):
mel_loss = nn.MSELoss()(mel_out, mel_target) + \ mel_loss = nn.MSELoss()(mel_out, mel_target) + \
nn.MSELoss()(mel_out_postnet, mel_target) nn.MSELoss()(mel_out_postnet, mel_target)
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
self.last_mel_loss = mel_loss.detach().clone().mean().item()
self.last_gate_loss = gate_loss.detach().clone().mean().item()
return mel_loss + gate_loss return mel_loss + gate_loss
def extra_metrics(self):
return {
'mel_loss': self.last_mel_loss,
'gate_loss': self.last_gate_loss
}

View File

@ -1,5 +1,6 @@
import importlib import importlib
import logging import logging
import os
import pkgutil import pkgutil
import sys import sys
from collections import OrderedDict from collections import OrderedDict
@ -36,6 +37,8 @@ def find_registered_model_fns(base_path='models'):
found_fns = {} found_fns = {}
module_iter = pkgutil.walk_packages([base_path]) module_iter = pkgutil.walk_packages([base_path])
for mod in module_iter: for mod in module_iter:
if os.getcwd() not in mod.module_finder.path:
continue # I have no idea why this is necessary - I think it's a bug in the latest PyWindows release.
if mod.ispkg: if mod.ispkg:
EXCLUSION_LIST = ['flownet2'] EXCLUSION_LIST = ['flownet2']
if mod.name not in EXCLUSION_LIST: if mod.name not in EXCLUSION_LIST: