DL-Art-School/codes/trainer/eval/mel_evaluator.py
2021-07-08 22:13:44 -06:00

45 lines
1.6 KiB
Python

import torch
import trainer.eval.evaluator as evaluator
from data import create_dataset
from data.audio.nv_tacotron_dataset import TextMelCollate
from models.tacotron2.loss import Tacotron2LossRaw
from torch.utils.data import DataLoader
from tqdm import tqdm
# Evaluates a mel-spectrogram predictor by its average Tacotron2 loss over an evaluation dataset.
class MelEvaluator(evaluator.Evaluator):
    """Evaluates a mel-spectrogram predictor by its average Tacotron2 loss.

    The model is run over ``opt_eval['dataset']`` in batches of
    ``opt_eval['batch_size']``. Its outputs are scored against the padded
    mel and gate targets with ``Tacotron2LossRaw``.
    """

    # Maps the model's forward() keyword arguments to the batch keys produced
    # by TextMelCollate. Built once rather than on every batch.
    _MODEL_INPUTS = {
        'text_inputs': 'padded_text',
        'text_lengths': 'input_lengths',
        'mels': 'padded_mel',
        'output_lengths': 'output_lengths',
    }
    # Batch keys passed, in this order, to the criterion as targets.
    _TARGETS = ('padded_mel', 'padded_gate')

    def __init__(self, model, opt_eval, env):
        """
        Args:
            model: network under evaluation; called with the keyword arguments
                text_inputs, text_lengths, mels and output_lengths.
            opt_eval: evaluator options; must contain 'batch_size' (not None)
                and 'dataset' (the options passed to create_dataset).
            env: trainer environment; env['device'] is the device each batch
                is moved to during evaluation.

        Raises:
            ValueError: if opt_eval['batch_size'] is None.
        """
        super().__init__(model, opt_eval, env, uses_all_ddp=True)
        self.batch_sz = opt_eval['batch_size']
        # Validate before building the dataset, which may be expensive. An
        # explicit raise (unlike assert) survives `python -O`.
        if self.batch_sz is None:
            raise ValueError("MelEvaluator requires opt_eval['batch_size'] to be set")
        self.dataset = create_dataset(opt_eval['dataset'])
        self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1,
                                     collate_fn=TextMelCollate(n_frames_per_step=1))
        self.criterion = Tacotron2LossRaw()

    def perform_eval(self):
        """Run the model over the evaluation set and report its mean loss.

        The model is put in eval mode for the pass and switched back to train
        mode afterwards, even if evaluation raises.

        Returns:
            {"validation-score": float}: the unweighted mean of the per-batch
            criterion values. It is NaN if the dataloader yields no batches,
            instead of raising ZeroDivisionError.
        """
        device = self.env['device']
        total_error = 0.0
        batch_count = 0
        self.model.eval()
        try:
            # The loss is only reported, never backpropagated, so autograd is
            # disabled for the criterion as well as the forward pass.
            with torch.no_grad():
                for batch in tqdm(self.dataloader):
                    params = {k: batch[v].to(device) for k, v in self._MODEL_INPUTS.items()}
                    pred = self.model(**params)
                    targets = [batch[t].to(device) for t in self._TARGETS]
                    total_error += self.criterion(pred, targets).item()
                    batch_count += 1
        finally:
            # Restore training mode regardless of how evaluation ended.
            self.model.train()
        if batch_count == 0:
            # No data to score: report NaN rather than crash training with a
            # ZeroDivisionError.
            return {"validation-score": float('nan')}
        return {"validation-score": total_error / batch_count}