Update diffusion_noise_surfer to support audio

James Betker 2021-09-01 08:34:47 -06:00
parent 3e073cff85
commit 92e7e57f81
3 changed files with 59 additions and 33 deletions


@@ -8,11 +8,13 @@ from collections import OrderedDict
 import numpy
 from PIL import Image
+from scipy.io import wavfile
 from torchvision.transforms import ToTensor
 import utils
 import utils.options as option
 import utils.util as util
+from data.audio.wavfile_dataset import load_audio_from_wav
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from data import create_dataset, create_dataloader
 from tqdm import tqdm
@@ -21,7 +23,7 @@ import numpy as np

 # A rough copy of test.py that "surfs" along a set of random noise priors to show the effect of Gaussian noise on the results.
-def forward_pass(model, data, output_dir, spacing):
+def forward_pass(model, data, output_dir, spacing, audio_mode):
     with torch.no_grad():
         model.feed_data(data, 0)
         model.test()
@@ -29,13 +31,17 @@ def forward_pass(model, data, output_dir, spacing):
         visuals = model.get_current_visuals()['rlt'].cpu()
         img_path = data['GT_path'][0]
         img_name = osp.splitext(osp.basename(img_path))[0]
-        sr_img = util.tensor2img(visuals[0])  # uint8
+        sr_img = visuals[0]

         # save images
         suffixes = [f'_{int(spacing)}']
         for suffix in suffixes:
-            save_img_path = osp.join(output_dir, img_name + suffix + '.png')
-            util.save_img(sr_img, save_img_path)
+            if audio_mode:
+                save_img_path = osp.join(output_dir, img_name + suffix + '.wav')
+                wavfile.write(osp.join(output_dir, save_img_path), 22050, sr_img[0].cpu().numpy())
+            else:
+                save_img_path = osp.join(output_dir, img_name + suffix + '.png')
+                util.save_img(util.tensor2img(sr_img), save_img_path)

 if __name__ == "__main__":
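
For context on the new audio branch above: scipy.io.wavfile.write takes (filename, rate, data) and, given a float32 array, writes a 32-bit float WAV, so the tensor can be saved without the uint8 conversion util.tensor2img performs for images. (The extra osp.join in the diff is a no-op when output_dir is absolute, since os.path.join discards earlier components before an absolute path, but would double the prefix for a relative output_dir.) A minimal standalone sketch, with the 22050 Hz rate and 40960-sample length taken from this script and the random array standing in for sr_img[0]:

    import numpy as np
    from scipy.io import wavfile

    # Stand-in for sr_img[0].cpu().numpy(): a 1-D float32 waveform in [-1, 1].
    waveform = np.clip(np.random.randn(40960), -1, 1).astype(np.float32)
    wavfile.write('sample_0.wav', 22050, waveform)  # float32 -> 32-bit float WAV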
@@ -45,10 +51,11 @@ if __name__ == "__main__":
     np.random.seed(5555)

     #### options
+    audio_mode = True  # Whether to render audio or images.
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -62,6 +69,12 @@ if __name__ == "__main__":
     logger.info(option.dict2str(opt))

     # Load test image
-    im = ToTensor()(Image.open(opt['image'])) * 2 - 1
+    if audio_mode:
+        im = load_audio_from_wav(opt['image'], opt['sample_rate'])
+        im = im[:, :(im.shape[1]//4096)*4096]
+        # Hack to reduce memory usage (but cuts off sample):
+        im = im[:, :40960]
+    else:
+        im = ToTensor()(Image.open(opt['image'])) * 2 - 1
     _, h, w = im.shape
     if h % 2 == 1:
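
The slice im[:, :(im.shape[1]//4096)*4096] trims the clip to a whole multiple of 4096 samples, presumably so the waveform length divides evenly through every downsampling stage of the diffusion network; the follow-up slice to 40960 samples is, per the in-line comment, only a memory-saving hack. A hypothetical helper showing the alignment arithmetic:

    import torch

    def crop_to_multiple(clip: torch.Tensor, factor: int = 4096) -> torch.Tensor:
        # Drop trailing samples so the length is an exact multiple of `factor`.
        usable = (clip.shape[1] // factor) * factor
        return clip[:, :usable]

    clip = torch.randn(1, 50000)          # (channels, samples)
    print(crop_to_multiple(clip).shape)   # torch.Size([1, 49152]) -- 12 * 4096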
@@ -80,15 +93,21 @@ if __name__ == "__main__":
     # Build the corruption indexes we are going to use.
     correction_factors = opt['correction_factor']

-    opt['steps']['generator']['injectors']['visual_debug']['zero_noise'] = False
+    #opt['steps']['generator']['injectors']['visual_debug']['zero_noise'] = False
     model = ExtensibleTrainer(opt)
     results_dir = osp.join(opt['path']['results_root'], os.path.basename(opt['image']))
     util.mkdir(results_dir)

     for i in range(10):
-        data = {
-            'hq': im.to('cuda'),
-            'corruption_entropy': torch.tensor([correction_factors], device='cuda',
-                                               dtype=torch.float),
-            'GT_path': opt['image']
-        }
-        forward_pass(model, data, results_dir, i)
+        if audio_mode:
+            data = {
+                'clip': im.to('cuda'),
+                'GT_path': opt['image']
+            }
+        else:
+            data = {
+                'hq': im.to('cuda'),
+                'corruption_entropy': torch.tensor([correction_factors], device='cuda',
+                                                   dtype=torch.float),
+                'GT_path': opt['image']
+            }
+        forward_pass(model, data, results_dir, i, audio_mode)


@@ -97,14 +97,21 @@ class GaussianDiffusionInferenceInjector(Injector):
         model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
         gen.eval()
         with torch.no_grad():
-            output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
-                            model_inputs['low_res'].shape[-1] * self.output_scale_factor)
+            if 'low_res' in model_inputs.keys():
+                output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
+                                model_inputs['low_res'].shape[-1] * self.output_scale_factor)
+                dev = model_inputs['low_res'].device
+            elif 'spectrogram' in model_inputs.keys():
+                output_shape = (self.output_batch_size, 1, model_inputs['spectrogram'].shape[-1]*256)
+                dev = model_inputs['spectrogram'].device
+            else:
+                raise NotImplementedError
             noise = None
             if self.noise_style == 'zero':
-                noise = torch.zeros(output_shape, device=model_inputs['low_res'].device)
+                noise = torch.zeros(output_shape, device=dev)
             elif self.noise_style == 'fixed':
                 if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
-                    self.fixed_noise = torch.randn(output_shape, device=model_inputs['low_res'].device)
+                    self.fixed_noise = torch.randn(output_shape, device=dev)
                 noise = self.fixed_noise
             gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True)
             if self.undo_n1_to_1:
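
The new elif branch derives the waveform shape from the conditioning spectrogram: one output channel, with 256 waveform samples per spectrogram frame (256 looks like the vocoder's hop length), and the dev variable lets the zero/fixed noise styles work without a low_res input. A small sketch of that shape arithmetic, with an assumed 80-bin mel input:

    import torch

    batch, n_mels, frames = 1, 80, 160               # n_mels=80 is an assumed value
    spectrogram = torch.randn(batch, n_mels, frames)
    output_shape = (batch, 1, spectrogram.shape[-1] * 256)  # 256 samples per frame
    noise = torch.randn(output_shape, device=spectrogram.device)
    print(noise.shape)                               # torch.Size([1, 1, 40960])

With 160 frames this yields 40960 samples, which matches the memory-saving crop in the surfer script above.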