forked from mrq/DL-Art-School
tacotron2 work
This commit is contained in:
parent
fe0c699ced
commit
5584cfcc7a
|
@ -11,6 +11,8 @@ class Tacotron2Loss(ConfigurableLoss):
|
||||||
self.mel_output_postnet_key = opt_loss['mel_output_postnet_key']
|
self.mel_output_postnet_key = opt_loss['mel_output_postnet_key']
|
||||||
self.gate_target_key = opt_loss['gate_target_key']
|
self.gate_target_key = opt_loss['gate_target_key']
|
||||||
self.gate_output_key = opt_loss['gate_output_key']
|
self.gate_output_key = opt_loss['gate_output_key']
|
||||||
|
self.last_mel_loss = 0
|
||||||
|
self.last_gate_loss = 0
|
||||||
|
|
||||||
def forward(self, _, state):
|
def forward(self, _, state):
|
||||||
mel_target, gate_target = state[self.mel_target_key], state[self.gate_target_key]
|
mel_target, gate_target = state[self.mel_target_key], state[self.gate_target_key]
|
||||||
|
@ -23,12 +25,22 @@ class Tacotron2Loss(ConfigurableLoss):
|
||||||
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
||||||
nn.MSELoss()(mel_out_postnet, mel_target)
|
nn.MSELoss()(mel_out_postnet, mel_target)
|
||||||
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
|
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
|
||||||
|
self.last_mel_loss = mel_loss.detach().clone().mean().item()
|
||||||
|
self.last_gate_loss = gate_loss.detach().clone().mean().item()
|
||||||
return mel_loss + gate_loss
|
return mel_loss + gate_loss
|
||||||
|
|
||||||
|
def extra_metrics(self):
|
||||||
|
return {
|
||||||
|
'mel_loss': self.last_mel_loss,
|
||||||
|
'gate_loss': self.last_gate_loss
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Tacotron2LossRaw(nn.Module):
|
class Tacotron2LossRaw(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.last_mel_loss = 0
|
||||||
|
self.last_gate_loss = 0
|
||||||
|
|
||||||
def forward(self, model_output, targets):
|
def forward(self, model_output, targets):
|
||||||
mel_target, gate_target = targets[0], targets[1]
|
mel_target, gate_target = targets[0], targets[1]
|
||||||
|
@ -41,4 +53,12 @@ class Tacotron2LossRaw(nn.Module):
|
||||||
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
|
||||||
nn.MSELoss()(mel_out_postnet, mel_target)
|
nn.MSELoss()(mel_out_postnet, mel_target)
|
||||||
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
|
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
|
||||||
|
self.last_mel_loss = mel_loss.detach().clone().mean().item()
|
||||||
|
self.last_gate_loss = gate_loss.detach().clone().mean().item()
|
||||||
return mel_loss + gate_loss
|
return mel_loss + gate_loss
|
||||||
|
|
||||||
|
def extra_metrics(self):
|
||||||
|
return {
|
||||||
|
'mel_loss': self.last_mel_loss,
|
||||||
|
'gate_loss': self.last_gate_loss
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import sys
|
import sys
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
@ -36,6 +37,8 @@ def find_registered_model_fns(base_path='models'):
|
||||||
found_fns = {}
|
found_fns = {}
|
||||||
module_iter = pkgutil.walk_packages([base_path])
|
module_iter = pkgutil.walk_packages([base_path])
|
||||||
for mod in module_iter:
|
for mod in module_iter:
|
||||||
|
if os.getcwd() not in mod.module_finder.path:
|
||||||
|
continue # I have no idea why this is necessary - I think it's a bug in the latest PyWindows release.
|
||||||
if mod.ispkg:
|
if mod.ispkg:
|
||||||
EXCLUSION_LIST = ['flownet2']
|
EXCLUSION_LIST = ['flownet2']
|
||||||
if mod.name not in EXCLUSION_LIST:
|
if mod.name not in EXCLUSION_LIST:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user