Misc fixes for diffusion nets

This commit is contained in:
James Betker 2021-06-21 10:38:07 -06:00
parent 8e3a33e001
commit e7890dc0ba
6 changed files with 38 additions and 6 deletions

View File

@ -173,13 +173,19 @@ class RRDBNet(nn.Module):
default_init_weights(m, 1.0)
default_init_weights(self.conv_last, 0)
def forward(self, x, timesteps, low_res=None):
def forward(self, x, timesteps, low_res, correction_factors=None):
emb = self.time_embed(timestep_embedding(timesteps, self.mid_channels))
_, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
x = torch.cat([x, upsampled], dim=1)
if correction_factors is not None:
correction_factors = correction_factors.view(x.shape[0], -1, 1, 1).repeat(1, 1, new_height, new_width)
else:
correction_factors = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device)
x = torch.cat([x, correction_factors], dim=1)
d1 = self.input_block(x)
d2 = self.down1(d1)
feat = self.down2(d2)

View File

@ -76,7 +76,7 @@ if __name__ == "__main__":
im = im[:,dh:-dh]
if dw > 0:
im = im[:,:,dw:-dw]
im = im.unsqueeze(0)
im = im[:3].unsqueeze(0)
# Build the corruption indexes we are going to use.
jpegs = list(numpy.arange(opt['min_jpeg_correction'], opt['max_jpeg_correction'], opt['jpeg_correction_step_size']))

View File

@ -75,7 +75,7 @@ if __name__ == "__main__":
im = im[:,dh:-dh]
if dw > 0:
im = im[:,:,dw:-dw]
im = im.unsqueeze(0)
im = im[:3].unsqueeze(0)
# Build the corruption indexes we are going to use.
correction_factors = opt['correction_factor']

View File

@ -0,0 +1,20 @@
import glob
import torch
import torchvision
from PIL import Image
from torchvision.transforms import ToTensor
if __name__ == '__main__':
imfolder = 'F:\\dlas\\results\\test_diffusion_unet\\imgset5'
cols, rows = 10, 5
images = glob.glob(f'{imfolder}/*.png')
output = None
for r in range(rows):
for c in range(cols):
im = ToTensor()(Image.open(next(images)))
if output is None:
c, h, w = im.shape
output = torch.zeros(c, h * rows, w * cols)
output[:,r*h:(r+1)*h,c*w:(c+1)*w] = im
torchvision.utils.save_image(output, "out.png")

View File

@ -299,7 +299,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_sm.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_xstart.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -54,7 +54,7 @@ class GaussianDiffusionInferenceInjector(Injector):
self.sampling_fn = self.diffusion.ddim_sample_loop if use_ddim else self.diffusion.p_sample_loop
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
self.use_ema_model = opt_get(opt, ['use_ema'], False)
self.zero_noise = opt_get(opt, ['zero_noise'], False)
self.noise_style = opt_get(opt, ['noise_type'], 'random') # 'zero', 'fixed' or 'random'
def forward(self, state):
if self.use_ema_model:
@ -66,7 +66,13 @@ class GaussianDiffusionInferenceInjector(Injector):
with torch.no_grad():
output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
model_inputs['low_res'].shape[-1] * self.output_scale_factor)
noise = torch.zeros(output_shape, device=model_inputs['low_res'].device) if self.zero_noise else None
noise = None
if self.noise_style == 'zero':
noise = torch.zeros(output_shape, device=model_inputs['low_res'].device)
elif self.noise_style == 'fixed':
if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
self.fixed_noise = torch.randn(output_shape, device=model_inputs['low_res'].device)
noise = self.fixed_noise
gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True)
if self.undo_n1_to_1:
gen = (gen + 1) / 2