diff --git a/codes/models/diffusion/rrdb_diffusion.py b/codes/models/diffusion/rrdb_diffusion.py
index 0cbbe8ab..9deab2f3 100644
--- a/codes/models/diffusion/rrdb_diffusion.py
+++ b/codes/models/diffusion/rrdb_diffusion.py
@@ -173,13 +173,22 @@ class RRDBNet(nn.Module):
                 default_init_weights(m, 1.0)
         default_init_weights(self.conv_last, 0)
 
-    def forward(self, x, timesteps, low_res=None):
+    def forward(self, x, timesteps, low_res, correction_factors=None):
         emb = self.time_embed(timestep_embedding(timesteps, self.mid_channels))
 
         _, _, new_height, new_width = x.shape
         upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
         x = torch.cat([x, upsampled], dim=1)
 
+        # Tile the per-sample correction factors over the spatial dims so they can be concatenated
+        # as extra input channels; an all-zero map is used when none are supplied.
+        if correction_factors is not None:
+            correction_factors = correction_factors.view(x.shape[0], -1, 1, 1).repeat(1, 1, new_height, new_width)
+        else:
+            # The batch size is read from x; the name `b` was never defined in this scope.
+            correction_factors = torch.zeros((x.shape[0], self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device)
+        x = torch.cat([x, correction_factors], dim=1)
+
         d1 = self.input_block(x)
         d2 = self.down1(d1)
         feat = self.down2(d2)
diff --git a/codes/scripts/diffusion/diffusion_correction_surfer.py b/codes/scripts/diffusion/diffusion_correction_surfer.py
index 05eb25ce..0a319ba0 100644
--- a/codes/scripts/diffusion/diffusion_correction_surfer.py
+++ b/codes/scripts/diffusion/diffusion_correction_surfer.py
@@ -76,7 +76,7 @@ if __name__ == "__main__":
         im = im[:,dh:-dh]
     if dw > 0:
         im = im[:,:,dw:-dw]
-    im = im.unsqueeze(0)
+    im = im[:3].unsqueeze(0)
 
     # Build the corruption indexes we are going to use.
     jpegs = list(numpy.arange(opt['min_jpeg_correction'], opt['max_jpeg_correction'], opt['jpeg_correction_step_size']))
diff --git a/codes/scripts/diffusion/diffusion_noise_surfer.py b/codes/scripts/diffusion/diffusion_noise_surfer.py
index 805bb8c5..51bcc4f7 100644
--- a/codes/scripts/diffusion/diffusion_noise_surfer.py
+++ b/codes/scripts/diffusion/diffusion_noise_surfer.py
@@ -75,7 +75,7 @@ if __name__ == "__main__":
         im = im[:,dh:-dh]
     if dw > 0:
         im = im[:,:,dw:-dw]
-    im = im.unsqueeze(0)
+    im = im[:3].unsqueeze(0)
 
     # Build the corruption indexes we are going to use.
     correction_factors = opt['correction_factor']
diff --git a/codes/scripts/stitch_images.py b/codes/scripts/stitch_images.py
new file mode 100644
index 00000000..c63117eb
--- /dev/null
+++ b/codes/scripts/stitch_images.py
@@ -0,0 +1,23 @@
+import glob
+
+import torch
+import torchvision
+from PIL import Image
+from torchvision.transforms import ToTensor
+
+if __name__ == '__main__':
+    # Tile the PNGs found in imfolder into a single rows x cols grid image.
+    imfolder = 'F:\\dlas\\results\\test_diffusion_unet\\imgset5'
+    cols, rows = 10, 5
+    # glob.glob() returns a list, so wrap it in an iterator for next(); sort for a stable tile order.
+    images = iter(sorted(glob.glob(f'{imfolder}/*.png')))
+    output = None
+    for r in range(rows):
+        for c in range(cols):
+            im = ToTensor()(Image.open(next(images)))
+            if output is None:
+                # Do not unpack into `c` here: it is the column index used for placement below.
+                chans, h, w = im.shape
+                output = torch.zeros(chans, h * rows, w * cols)
+            output[:,r*h:(r+1)*h,c*w:(c+1)*w] = im
+    torchvision.utils.save_image(output, "out.png")
\ No newline at end of file
diff --git a/codes/train.py b/codes/train.py
index f95e4baf..917be288 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -299,7 +299,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_sm.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_xstart.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index be5273cf..69efda92 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -54,7 +54,7 @@ class GaussianDiffusionInferenceInjector(Injector):
         self.sampling_fn = self.diffusion.ddim_sample_loop if use_ddim else self.diffusion.p_sample_loop
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
         self.use_ema_model = opt_get(opt, ['use_ema'], False)
-        self.zero_noise = opt_get(opt, ['zero_noise'], False)
+        self.noise_style = opt_get(opt, ['noise_type'], 'random')  # 'zero', 'fixed' or 'random'
 
     def forward(self, state):
         if self.use_ema_model:
@@ -66,7 +66,14 @@ class GaussianDiffusionInferenceInjector(Injector):
         with torch.no_grad():
             output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
                             model_inputs['low_res'].shape[-1] * self.output_scale_factor)
-            noise = torch.zeros(output_shape, device=model_inputs['low_res'].device) if self.zero_noise else None
+            noise = None
+            if self.noise_style == 'zero':
+                noise = torch.zeros(output_shape, device=model_inputs['low_res'].device)
+            elif self.noise_style == 'fixed':
+                # Reuse one cached noise tensor across calls; redraw it only when the output shape changes.
+                if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
+                    self.fixed_noise = torch.randn(output_shape, device=model_inputs['low_res'].device)
+                noise = self.fixed_noise
             gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True)
             if self.undo_n1_to_1:
                 gen = (gen + 1) / 2