Add waveglow & inference capabilities to audio generator

commit be2745f42d
parent 3febe6cbf4

codes/models/waveglow/__init__.py (new file, empty)

codes/models/waveglow/denoiser.py (new file)
@@ -0,0 +1,42 @@
import sys

from models.tacotron2.stft import STFT

sys.path.append('tacotron2')

import torch


class Denoiser(torch.nn.Module):
    """ Removes model bias from audio produced with waveglow """

    def __init__(self, waveglow, filter_length=1024, n_overlap=4,
                 win_length=1024, mode='zeros'):
        super(Denoiser, self).__init__()
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length/n_overlap),
                         win_length=win_length).cuda()
        if mode == 'zeros':
            mel_input = torch.zeros(
                (1, 80, 88),
                dtype=waveglow.upsample.weight.dtype,
                device=waveglow.upsample.weight.device)
        elif mode == 'normal':
            mel_input = torch.randn(
                (1, 80, 88),
                dtype=waveglow.upsample.weight.dtype,
                device=waveglow.upsample.weight.device)
        else:
raise Exception("Mode {} if not supported".format(mode))

        with torch.no_grad():
            bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
            bias_spec, _ = self.stft.transform(bias_audio)

        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.1):
        audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
        audio_spec_denoised = audio_spec - self.bias_spec * strength
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
        return audio_denoised
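For context, a minimal usage sketch of the denoiser; the model variable, the mel input, and the sigma/strength values below are illustrative assumptions, not part of this commit:

# Illustrative only: `waveglow` is a loaded, eval-mode WaveGlow model and
# `mel` is a (1, 80, frames) mel spectrogram on the same CUDA device.
denoiser = Denoiser(waveglow, mode='zeros')
with torch.no_grad():
    audio = waveglow.infer(mel, sigma=0.6)   # raw vocoder output, (1, time)
    audio = denoiser(audio, strength=0.01)   # subtract the cached bias spectrum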

codes/models/waveglow/waveglow.py (new file)

@@ -0,0 +1,318 @@
# *****************************************************************************
#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#      * Neither the name of the NVIDIA CORPORATION nor the
#        names of its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import copy
import torch
from torch.autograd import Variable
import torch.nn.functional as F

from trainer.networks import register_model


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a+input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts
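For reference, this fused kernel computes the WaveNet-style gated activation on the sum of the layer input and the conditioning projection, with the first n_channels channels passed through tanh and the remainder acting as a sigmoid gate:

$\text{acts} = \tanh\big((a+b)_{[:n]}\big) \odot \sigma\big((a+b)_{[n:]}\big), \quad n = \text{n\_channels}$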


class WaveGlowLoss(torch.nn.Module):
    def __init__(self, sigma=1.0):
        super(WaveGlowLoss, self).__init__()
        self.sigma = sigma

    def forward(self, model_output):
        z, log_s_list, log_det_W_list = model_output
        for i, log_s in enumerate(log_s_list):
            if i == 0:
                log_s_total = torch.sum(log_s)
                log_det_W_total = log_det_W_list[i]
            else:
                log_s_total = log_s_total + torch.sum(log_s)
                log_det_W_total += log_det_W_list[i]

        loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
        return loss/(z.size(0)*z.size(1)*z.size(2))
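This is the per-element negative log-likelihood of the audio under the flow (up to an additive constant), via the change-of-variables formula; in the notation of the code, with B, C, T = z.size(0), z.size(1), z.size(2):

$\mathcal{L} = \frac{1}{BCT}\left(\frac{\sum z^2}{2\sigma^2} - \sum_j \sum \log s_j - \sum_k \log\det W_k\right)$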


class Invertible1x1Conv(torch.nn.Module):
    """
    The layer outputs both the convolution and the log determinant
    of its weight matrix. If reverse=True it does convolution with
    the inverse.
    """
    def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Sample a random orthonormal matrix to initialize weights
        W = torch.qr(torch.FloatTensor(c, c).normal_())[0]

        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:,0] = -1*W[:,0]
        W = W.view(c, c, 1)
        self.conv.weight.data = W

    def forward(self, z, reverse=False):
        # shape
        batch_size, group_size, n_of_groups = z.size()

        W = self.conv.weight.squeeze()

        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Reverse computation
                W_inverse = W.float().inverse()
                W_inverse = Variable(W_inverse[..., None])
                if z.type() == 'torch.cuda.HalfTensor':
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            return z
        else:
            # Forward computation
            log_det_W = batch_size * n_of_groups * torch.logdet(W)
            z = self.conv(z)
            return z, log_det_W
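Since the same c x c matrix W multiplies each of the n_of_groups positions in every batch element, the Jacobian of this layer is block diagonal and the scalar computed above follows directly:

$\log\det J = \text{batch\_size} \cdot \text{n\_of\_groups} \cdot \log\det W$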


class WN(torch.nn.Module):
    """
    This is the WaveNet-like layer for the affine coupling. The primary
    difference from WaveNet is that the convolutions need not be causal.
    There is also no dilation size reset; the dilation only doubles on
    each layer.
    """
    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 kernel_size):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        assert(n_channels % 2 == 0)
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()

        start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name='weight')
        self.start = start

        # Initializing the last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability.
        end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
        self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = 2 ** i
            padding = int((kernel_size*dilation - dilation)/2)
            in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, forward_input):
        audio, spect = forward_input
        audio = self.start(audio)
        output = torch.zeros_like(audio)
        n_channels_tensor = torch.IntTensor([self.n_channels])

        spect = self.cond_layer(spect)

        for i in range(self.n_layers):
            spect_offset = i*2*self.n_channels
            acts = fused_add_tanh_sigmoid_multiply(
                self.in_layers[i](audio),
                spect[:,spect_offset:spect_offset+2*self.n_channels,:],
                n_channels_tensor)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                audio = audio + res_skip_acts[:,:self.n_channels,:]
                output = output + res_skip_acts[:,self.n_channels:,:]
            else:
                output = output + res_skip_acts

        return self.end(output)


class WaveGlow(torch.nn.Module):
    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
                 n_early_size, WN_config):
        super(WaveGlow, self).__init__()

        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
                                                 n_mel_channels,
                                                 1024, stride=256)
        assert(n_group % 2 == 0)
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.WN = torch.nn.ModuleList()
        self.convinv = torch.nn.ModuleList()

        n_half = int(n_group/2)

        # Set up layers with the right sizes based on how many dimensions
        # have been output already
        n_remaining_channels = n_group
        for k in range(n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - int(self.n_early_size/2)
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
        self.n_remaining_channels = n_remaining_channels  # Useful during inference

    def forward(self, forward_input):
        """
        forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
        forward_input[1] = audio: batch x time
        """
        spect, audio = forward_input

        # Upsample spectrogram to size of audio
        spect = self.upsample(spect)
        assert(spect.size(2) >= audio.size(1))
        if spect.size(2) > audio.size(1):
            spect = spect[:, :, :audio.size(1)]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
        output_audio = []
        log_s_list = []
        log_det_W_list = []

        for k in range(self.n_flows):
            if k % self.n_early_every == 0 and k > 0:
                output_audio.append(audio[:,:self.n_early_size,:])
                audio = audio[:,self.n_early_size:,:]

            audio, log_det_W = self.convinv[k](audio)
            log_det_W_list.append(log_det_W)

            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half,:]
            audio_1 = audio[:,n_half:,:]

            output = self.WN[k]((audio_0, spect))
            log_s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(log_s)*audio_1 + b
            log_s_list.append(log_s)

            audio = torch.cat([audio_0, audio_1],1)

        output_audio.append(audio)
        return torch.cat(output_audio,1), log_s_list, log_det_W_list
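The unfold/permute pair above implements WaveGlow's "squeeze" step, folding groups of n_group consecutive samples into channels. A quick shape check with illustrative numbers (assumed, not from this commit):

import torch

audio = torch.randn(2, 16000)        # B=2, T=16000, n_group=8
groups = audio.unfold(1, 8, 8)       # (2, 2000, 8): 2000 windows of 8 samples
squeezed = groups.permute(0, 2, 1)   # (2, 8, 2000): samples become channels
assert squeezed.shape == (2, 8, 2000)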

    def infer(self, spect, sigma=1.0):
        spect = self.upsample(spect)
        # trim conv artifacts. maybe pad spec to kernel multiple
        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
        spect = spect[:, :, :-time_cutoff]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        if spect.type() == 'torch.cuda.HalfTensor':
            audio = torch.cuda.HalfTensor(spect.size(0),
                                          self.n_remaining_channels,
                                          spect.size(2)).normal_()
        else:
            audio = torch.cuda.FloatTensor(spect.size(0),
                                           self.n_remaining_channels,
                                           spect.size(2)).normal_()

        audio = torch.autograd.Variable(sigma*audio)

        for k in reversed(range(self.n_flows)):
            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half,:]
            audio_1 = audio[:,n_half:,:]

            output = self.WN[k]((audio_0, spect))

            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = (audio_1 - b)/torch.exp(s)
            audio = torch.cat([audio_0, audio_1],1)

            audio = self.convinv[k](audio, reverse=True)

            if k % self.n_early_every == 0 and k > 0:
                if spect.type() == 'torch.cuda.HalfTensor':
                    z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
                else:
                    z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
                audio = torch.cat((sigma*z, audio),1)

        audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
        return audio

    @staticmethod
    def remove_weightnorm(model):
        waveglow = model
        for WN in waveglow.WN:
            WN.start = torch.nn.utils.remove_weight_norm(WN.start)
            WN.in_layers = remove(WN.in_layers)
            WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
            WN.res_skip_layers = remove(WN.res_skip_layers)
        return waveglow


def remove(conv_list):
    new_conv_list = torch.nn.ModuleList()
    for old_conv in conv_list:
        old_conv = torch.nn.utils.remove_weight_norm(old_conv)
        new_conv_list.append(old_conv)
    return new_conv_list


@register_model
def register_nv_waveglow(opt_net, opt):
    return WaveGlow(**opt_net['args'])
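Putting the pieces together, a hedged end-to-end sketch of driving this vocoder directly. The hyperparameters mirror the common NVIDIA LJSpeech configuration and the checkpoint path/format is a placeholder; none of this is prescribed by the commit itself:

# Illustrative only; `mel` is a (B, 80, frames) mel spectrogram.
waveglow = WaveGlow(n_mel_channels=80, n_flows=12, n_group=8,
                    n_early_every=4, n_early_size=2,
                    WN_config={'n_layers': 8, 'n_channels': 256, 'kernel_size': 3})
waveglow.load_state_dict(torch.load('waveglow.pth'))  # hypothetical checkpoint
waveglow = WaveGlow.remove_weightnorm(waveglow).cuda().eval()
with torch.no_grad():
    audio = waveglow.infer(mel.cuda(), sigma=0.6)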

codes/scripts/audio/test_audio_gen.py (new file)

@@ -0,0 +1,72 @@
import os.path as osp
import logging
import random
import argparse

import utils
import utils.options as option
import utils.util as util
from models.waveglow.denoiser import Denoiser
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np
from scipy.io import wavfile


def forward_pass(model, denoiser, data, output_dir, opt, b):
    with torch.no_grad():
        model.feed_data(data, 0)
        model.test()
        waveforms = model.eval_state[opt['eval']['output_state']][0]
        waveforms = denoiser(waveforms)
        for i in range(waveforms.shape[0]):
            audio = waveforms[i][0].cpu().numpy()
            wavfile.write(osp.join(output_dir, f'{b}_{i}.wav'), 22050, audio)


if __name__ == "__main__":
    # Set seeds
    torch.manual_seed(5555)
    random.seed(5555)
    np.random.seed(5555)

    #### options
    torch.backends.cudnn.benchmark = True
    want_metrics = False
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_tacotron2_lj.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    utils.util.loaded_options = opt

    util.mkdirs(
        (path for key, path in opt['path'].items()
         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
        test_loader = create_dataloader(test_set, dataset_opt, collate_fn=collate_fn)
        logger.info('Number of test texts in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    model = ExtensibleTrainer(opt)

    denoiser = Denoiser(model.networks['waveglow'].module)  # Pretty hacky, need to figure out a better way to integrate this.

    batch = 0
    for test_loader in test_loaders:
        dataset_dir = opt['path']['results_root']
        util.mkdir(dataset_dir)

        tq = tqdm(test_loader)
        for data in tq:
            forward_pass(model, denoiser, data, dataset_dir, opt, batch)
            batch += 1
@@ -16,19 +16,25 @@ class GeneratorInjector(Injector):
     def __init__(self, opt, env):
         super(GeneratorInjector, self).__init__(opt, env)
         self.grad = opt['grad'] if 'grad' in opt.keys() else True
+        self.method = opt_get(opt, ['method'], None)  # If specified, this method is called instead of __call__()
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
+
+        if self.method is not None and hasattr(gen, 'module'):
+            gen = gen.module  # Dereference DDP wrapper.
+        method = gen if self.method is None else getattr(gen, self.method)
+
         with autocast(enabled=self.env['opt']['fp16']):
             if isinstance(self.input, list):
                 params = extract_params_from_state(self.input, state)
             else:
                 params = [state[self.input]]
             if self.grad:
-                results = gen(*params)
+                results = method(*params)
             else:
                 with torch.no_grad():
-                    results = gen(*params)
+                    results = method(*params)
         new_state = {}
         if isinstance(self.output, list):
             # Only dereference tuples or lists, not tensors.
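With this change an injector can call a named method on the generator (e.g. WaveGlow's infer) instead of its forward(). A hypothetical options dict; only 'method' and 'grad' are read by the code above, and the remaining key names are assumptions about the base Injector schema:

injector_opt = {
    'type': 'generator',
    'generator': 'waveglow',  # key into env['generators']
    'method': 'infer',        # resolved via getattr(gen, 'infer')
    'grad': False,            # run the call under torch.no_grad()
    'in': ['mel'],            # assumed input/output state keys
    'out': ['waveform'],
}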

@@ -393,6 +393,7 @@ def recursively_detach(v):
     return out
 
 def opt_get(opt, keys, default=None):
+    assert not isinstance(keys, str)  # Common mistake, better to assert.
     if opt is None:
         return default
     ret = opt
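A short illustration of the mistake the new assert catches (hypothetical options dict):

opt = {'networks': {'waveglow': {'type': 'nv_waveglow'}}}
opt_get(opt, ['networks', 'waveglow'])  # correct: a list of nested keys
opt_get(opt, 'networks')                # bug: would walk the string character by
                                        # character; the assert now fails fast instead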