forked from mrq/DL-Art-School
Misc fixes for diffusion nets
This commit is contained in:
parent
8e3a33e001
commit
e7890dc0ba
|
@ -173,13 +173,19 @@ class RRDBNet(nn.Module):
|
|||
default_init_weights(m, 1.0)
|
||||
default_init_weights(self.conv_last, 0)
|
||||
|
||||
def forward(self, x, timesteps, low_res=None):
|
||||
def forward(self, x, timesteps, low_res, correction_factors=None):
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.mid_channels))
|
||||
|
||||
_, _, new_height, new_width = x.shape
|
||||
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
|
||||
x = torch.cat([x, upsampled], dim=1)
|
||||
|
||||
if correction_factors is not None:
|
||||
correction_factors = correction_factors.view(x.shape[0], -1, 1, 1).repeat(1, 1, new_height, new_width)
|
||||
else:
|
||||
correction_factors = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device)
|
||||
x = torch.cat([x, correction_factors], dim=1)
|
||||
|
||||
d1 = self.input_block(x)
|
||||
d2 = self.down1(d1)
|
||||
feat = self.down2(d2)
|
||||
|
|
|
@ -76,7 +76,7 @@ if __name__ == "__main__":
|
|||
im = im[:,dh:-dh]
|
||||
if dw > 0:
|
||||
im = im[:,:,dw:-dw]
|
||||
im = im.unsqueeze(0)
|
||||
im = im[:3].unsqueeze(0)
|
||||
|
||||
# Build the corruption indexes we are going to use.
|
||||
jpegs = list(numpy.arange(opt['min_jpeg_correction'], opt['max_jpeg_correction'], opt['jpeg_correction_step_size']))
|
||||
|
|
|
@ -75,7 +75,7 @@ if __name__ == "__main__":
|
|||
im = im[:,dh:-dh]
|
||||
if dw > 0:
|
||||
im = im[:,:,dw:-dw]
|
||||
im = im.unsqueeze(0)
|
||||
im = im[:3].unsqueeze(0)
|
||||
|
||||
# Build the corruption indexes we are going to use.
|
||||
correction_factors = opt['correction_factor']
|
||||
|
|
20
codes/scripts/stitch_images.py
Normal file
20
codes/scripts/stitch_images.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import glob
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
if __name__ == '__main__':
|
||||
imfolder = 'F:\\dlas\\results\\test_diffusion_unet\\imgset5'
|
||||
cols, rows = 10, 5
|
||||
images = glob.glob(f'{imfolder}/*.png')
|
||||
output = None
|
||||
for r in range(rows):
|
||||
for c in range(cols):
|
||||
im = ToTensor()(Image.open(next(images)))
|
||||
if output is None:
|
||||
c, h, w = im.shape
|
||||
output = torch.zeros(c, h * rows, w * cols)
|
||||
output[:,r*h:(r+1)*h,c*w:(c+1)*w] = im
|
||||
torchvision.utils.save_image(output, "out.png")
|
|
@ -299,7 +299,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_sm.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_unet_diffusion_xstart.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -54,7 +54,7 @@ class GaussianDiffusionInferenceInjector(Injector):
|
|||
self.sampling_fn = self.diffusion.ddim_sample_loop if use_ddim else self.diffusion.p_sample_loop
|
||||
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
|
||||
self.use_ema_model = opt_get(opt, ['use_ema'], False)
|
||||
self.zero_noise = opt_get(opt, ['zero_noise'], False)
|
||||
self.noise_style = opt_get(opt, ['noise_type'], 'random') # 'zero', 'fixed' or 'random'
|
||||
|
||||
def forward(self, state):
|
||||
if self.use_ema_model:
|
||||
|
@ -66,7 +66,13 @@ class GaussianDiffusionInferenceInjector(Injector):
|
|||
with torch.no_grad():
|
||||
output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
|
||||
model_inputs['low_res'].shape[-1] * self.output_scale_factor)
|
||||
noise = torch.zeros(output_shape, device=model_inputs['low_res'].device) if self.zero_noise else None
|
||||
noise = None
|
||||
if self.noise_style == 'zero':
|
||||
noise = torch.zeros(output_shape, device=model_inputs['low_res'].device)
|
||||
elif self.noise_style == 'fixed':
|
||||
if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
|
||||
self.fixed_noise = torch.randn(output_shape, device=model_inputs['low_res'].device)
|
||||
noise = self.fixed_noise
|
||||
gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True)
|
||||
if self.undo_n1_to_1:
|
||||
gen = (gen + 1) / 2
|
||||
|
|
Loading…
Reference in New Issue
Block a user