Update diffusion_noise_surfer to support audio

This commit is contained in:
James Betker 2021-09-01 08:34:47 -06:00
parent 3e073cff85
commit 92e7e57f81
3 changed files with 59 additions and 33 deletions

View File

@ -8,11 +8,13 @@ from collections import OrderedDict
import numpy
from PIL import Image
from scipy.io import wavfile
from torchvision.transforms import ToTensor
import utils
import utils.options as option
import utils.util as util
from data.audio.wavfile_dataset import load_audio_from_wav
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
@ -21,7 +23,7 @@ import numpy as np
# A rough copy of test.py that "surfs" along a set of random noise priors to show the affect of gaussian noise on the results.
def forward_pass(model, data, output_dir, spacing):
def forward_pass(model, data, output_dir, spacing, audio_mode):
with torch.no_grad():
model.feed_data(data, 0)
model.test()
@ -29,13 +31,17 @@ def forward_pass(model, data, output_dir, spacing):
visuals = model.get_current_visuals()['rlt'].cpu()
img_path = data['GT_path'][0]
img_name = osp.splitext(osp.basename(img_path))[0]
sr_img = util.tensor2img(visuals[0]) # uint8
sr_img = visuals[0]
# save images
suffixes = [f'_{int(spacing)}']
for suffix in suffixes:
save_img_path = osp.join(output_dir, img_name + suffix + '.png')
util.save_img(sr_img, save_img_path)
if audio_mode:
save_img_path = osp.join(output_dir, img_name + suffix + '.wav')
wavfile.write(osp.join(output_dir, save_img_path), 22050, sr_img[0].cpu().numpy())
else:
save_img_path = osp.join(output_dir, img_name + suffix + '.png')
util.save_img(util.tensor2img(sr_img), save_img_path)
if __name__ == "__main__":
@ -45,10 +51,11 @@ if __name__ == "__main__":
np.random.seed(5555)
#### options
audio_mode = True # Whether to render audio or images.
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt
@ -62,33 +69,45 @@ if __name__ == "__main__":
logger.info(option.dict2str(opt))
# Load test image
im = ToTensor()(Image.open(opt['image'])) * 2 - 1
_, h, w = im.shape
if h % 2 == 1:
im = im[:,1:,:]
h = h-1
if w % 2 == 1:
im = im[:,:,1:]
w = w-1
dh, dw = (h - 32 * (h // 32)) // 2, (w - 32 * (w // 32)) // 2
if dh > 0:
im = im[:,dh:-dh]
if dw > 0:
im = im[:,:,dw:-dw]
im = im[:3].unsqueeze(0)
if audio_mode:
im = load_audio_from_wav(opt['image'], opt['sample_rate'])
im = im[:, :(im.shape[1]//4096)*4096]
# Hack to reduce memory usage (but cuts off sample):
im = im[:, :40960]
else:
im = ToTensor()(Image.open(opt['image'])) * 2 - 1
_, h, w = im.shape
if h % 2 == 1:
im = im[:,1:,:]
h = h-1
if w % 2 == 1:
im = im[:,:,1:]
w = w-1
dh, dw = (h - 32 * (h // 32)) // 2, (w - 32 * (w // 32)) // 2
if dh > 0:
im = im[:,dh:-dh]
if dw > 0:
im = im[:,:,dw:-dw]
im = im[:3].unsqueeze(0)
# Build the corruption indexes we are going to use.
correction_factors = opt['correction_factor']
# Build the corruption indexes we are going to use.
correction_factors = opt['correction_factor']
opt['steps']['generator']['injectors']['visual_debug']['zero_noise'] = False
#opt['steps']['generator']['injectors']['visual_debug']['zero_noise'] = False
model = ExtensibleTrainer(opt)
results_dir = osp.join(opt['path']['results_root'], os.path.basename(opt['image']))
util.mkdir(results_dir)
for i in range(10):
data = {
'hq': im.to('cuda'),
'corruption_entropy': torch.tensor([correction_factors], device='cuda',
dtype=torch.float),
'GT_path': opt['image']
}
forward_pass(model, data, results_dir, i)
if audio_mode:
data = {
'clip': im.to('cuda'),
'GT_path': opt['image']
}
else:
data = {
'hq': im.to('cuda'),
'corruption_entropy': torch.tensor([correction_factors], device='cuda',
dtype=torch.float),
'GT_path': opt['image']
}
forward_pass(model, data, results_dir, i, audio_mode)

View File

@ -97,14 +97,21 @@ class GaussianDiffusionInferenceInjector(Injector):
model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
gen.eval()
with torch.no_grad():
output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
model_inputs['low_res'].shape[-1] * self.output_scale_factor)
if 'low_res' in model_inputs.keys():
output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
model_inputs['low_res'].shape[-1] * self.output_scale_factor)
dev = model_inputs['low_res'].device
elif 'spectrogram' in model_inputs.keys():
output_shape = (self.output_batch_size, 1, model_inputs['spectrogram'].shape[-1]*256)
dev = model_inputs['spectrogram'].device
else:
raise NotImplementedError
noise = None
if self.noise_style == 'zero':
noise = torch.zeros(output_shape, device=model_inputs['low_res'].device)
noise = torch.zeros(output_shape, device=dev)
elif self.noise_style == 'fixed':
if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
self.fixed_noise = torch.randn(output_shape, device=model_inputs['low_res'].device)
self.fixed_noise = torch.randn(output_shape, device=dev)
noise = self.fixed_noise
gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True)
if self.undo_n1_to_1: