diff --git a/codes/scripts/diffusion/diffusion_noise_surfer.py b/codes/scripts/diffusion/diffusion_noise_surfer.py
index 51bcc4f7..f6ebefd0 100644
--- a/codes/scripts/diffusion/diffusion_noise_surfer.py
+++ b/codes/scripts/diffusion/diffusion_noise_surfer.py
@@ -8,11 +8,13 @@ from collections import OrderedDict
 import numpy
 from PIL import Image
+from scipy.io import wavfile
 from torchvision.transforms import ToTensor
 
 import utils
 import utils.options as option
 import utils.util as util
+from data.audio.wavfile_dataset import load_audio_from_wav
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from data import create_dataset, create_dataloader
 from tqdm import tqdm
 
@@ -21,7 +23,7 @@ import numpy as np
 
 
 # A rough copy of test.py that "surfs" along a set of random noise priors to show the effect of Gaussian noise on the results.
-def forward_pass(model, data, output_dir, spacing):
+def forward_pass(model, data, output_dir, spacing, audio_mode):
     with torch.no_grad():
         model.feed_data(data, 0)
         model.test()
@@ -29,13 +31,17 @@ def forward_pass(model, data, output_dir, spacing):
     visuals = model.get_current_visuals()['rlt'].cpu()
     img_path = data['GT_path'][0]
     img_name = osp.splitext(osp.basename(img_path))[0]
-    sr_img = util.tensor2img(visuals[0])  # uint8
+    sr_img = visuals[0]
 
     # save images
     suffixes = [f'_{int(spacing)}']
     for suffix in suffixes:
-        save_img_path = osp.join(output_dir, img_name + suffix + '.png')
-        util.save_img(sr_img, save_img_path)
+        if audio_mode:
+            save_img_path = osp.join(output_dir, img_name + suffix + '.wav')
+            wavfile.write(save_img_path, 22050, sr_img[0].cpu().numpy())
+        else:
+            save_img_path = osp.join(output_dir, img_name + suffix + '.png')
+            util.save_img(util.tensor2img(sr_img), save_img_path)
 
 
 if __name__ == "__main__":
@@ -45,10 +51,11 @@ if __name__ == "__main__":
     np.random.seed(5555)
 
     #### options
+    audio_mode = True  # Whether to render audio or images.
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -62,33 +69,45 @@ if __name__ == "__main__":
     logger.info(option.dict2str(opt))
 
     # Load test image
-    im = ToTensor()(Image.open(opt['image'])) * 2 - 1
-    _, h, w = im.shape
-    if h % 2 == 1:
-        im = im[:,1:,:]
-        h = h-1
-    if w % 2 == 1:
-        im = im[:,:,1:]
-        w = w-1
-    dh, dw = (h - 32 * (h // 32)) // 2, (w - 32 * (w // 32)) // 2
-    if dh > 0:
-        im = im[:,dh:-dh]
-    if dw > 0:
-        im = im[:,:,dw:-dw]
-    im = im[:3].unsqueeze(0)
+    if audio_mode:
+        im = load_audio_from_wav(opt['image'], opt['sample_rate'])
+        im = im[:, :(im.shape[1]//4096)*4096]
+        # Hack to reduce memory usage (but cuts off the sample):
+        im = im[:, :40960]
+    else:
+        im = ToTensor()(Image.open(opt['image'])) * 2 - 1
+        _, h, w = im.shape
+        if h % 2 == 1:
+            im = im[:,1:,:]
+            h = h-1
+        if w % 2 == 1:
+            im = im[:,:,1:]
+            w = w-1
+        dh, dw = (h - 32 * (h // 32)) // 2, (w - 32 * (w // 32)) // 2
+        if dh > 0:
+            im = im[:,dh:-dh]
+        if dw > 0:
+            im = im[:,:,dw:-dw]
+        im = im[:3].unsqueeze(0)
 
     # Build the corruption indexes we are going to use.
     correction_factors = opt['correction_factor']
 
-    opt['steps']['generator']['injectors']['visual_debug']['zero_noise'] = False
+    #opt['steps']['generator']['injectors']['visual_debug']['zero_noise'] = False
     model = ExtensibleTrainer(opt)
     results_dir = osp.join(opt['path']['results_root'], os.path.basename(opt['image']))
     util.mkdir(results_dir)
     for i in range(10):
-        data = {
-            'hq': im.to('cuda'),
-            'corruption_entropy': torch.tensor([correction_factors], device='cuda',
-                                               dtype=torch.float),
-            'GT_path': opt['image']
-        }
-        forward_pass(model, data, results_dir, i)
+        if audio_mode:
+            data = {
+                'clip': im.to('cuda'),
+                'GT_path': opt['image']
+            }
+        else:
+            data = {
+                'hq': im.to('cuda'),
+                'corruption_entropy': torch.tensor([correction_factors], device='cuda',
+                                                   dtype=torch.float),
+                'GT_path': opt['image']
+            }
+        forward_pass(model, data, results_dir, i, audio_mode)
diff --git a/codes/scripts/diffusion/diffusion_sampler.py b/codes/scripts/diffusion/diffusion_sampler.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 2d4969d9..f9fa05ca 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -97,14 +97,21 @@ class GaussianDiffusionInferenceInjector(Injector):
         model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
         gen.eval()
         with torch.no_grad():
-            output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
-                            model_inputs['low_res'].shape[-1] * self.output_scale_factor)
+            if 'low_res' in model_inputs.keys():
+                output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
+                                model_inputs['low_res'].shape[-1] * self.output_scale_factor)
+                dev = model_inputs['low_res'].device
+            elif 'spectrogram' in model_inputs.keys():
+                output_shape = (self.output_batch_size, 1, model_inputs['spectrogram'].shape[-1] * 256)
+                dev = model_inputs['spectrogram'].device
+            else:
+                raise NotImplementedError
             noise = None
             if self.noise_style == 'zero':
-                noise = torch.zeros(output_shape, device=model_inputs['low_res'].device)
+                noise = torch.zeros(output_shape, device=dev)
             elif self.noise_style == 'fixed':
                 if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
-                    self.fixed_noise = torch.randn(output_shape, device=model_inputs['low_res'].device)
+                    self.fixed_noise = torch.randn(output_shape, device=dev)
                 noise = self.fixed_noise
             gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True)
             if self.undo_n1_to_1:
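Reviewer note: the new branch in GaussianDiffusionInferenceInjector infers the output shape and device from whichever conditioning input is present. Below is a minimal standalone sketch of that logic, assuming 256 output samples per spectrogram frame (the constant hard-coded in the hunk above); infer_output_shape is a hypothetical helper written for illustration, not part of the patch.

import torch

# Sketch of the shape/device inference added to GaussianDiffusionInferenceInjector.
# Assumes 3-channel image outputs, 1-channel audio outputs, and 256 waveform
# samples per spectrogram frame (the vocoder hop length hard-coded in the patch).
def infer_output_shape(model_inputs, batch_size, scale_factor):
    if 'low_res' in model_inputs:
        lr = model_inputs['low_res']
        return (batch_size, 3, lr.shape[-2] * scale_factor,
                lr.shape[-1] * scale_factor), lr.device
    elif 'spectrogram' in model_inputs:
        mel = model_inputs['spectrogram']
        return (batch_size, 1, mel.shape[-1] * 256), mel.device
    raise NotImplementedError('No conditioning input to infer a shape from.')

# A 160-frame mel spectrogram maps to a 160 * 256 = 40960-sample waveform,
# consistent with the 40960-sample crop used in diffusion_noise_surfer.py above.
shape, dev = infer_output_shape({'spectrogram': torch.randn(1, 80, 160)},
                                batch_size=1, scale_factor=None)
assert shape == (1, 1, 40960)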