DL-Art-School/codes/models/SPSR_modules/sampler.py

95 lines
3.6 KiB
Python
Raw Normal View History

import random
import torch
import numpy as np
def _get_random_crop_indices(crop_region, crop_size):
    """Pick a uniformly random crop window of ``crop_size`` inside ``crop_region``.

    Args:
        crop_region: ``(start_y, end_y, start_x, end_x)`` bounds; the end
            values are exclusive (the returned windows are used as slices).
        crop_size: ``(size_y, size_x)`` of the window.

    Returns:
        ``(start_y, end_y, start_x, end_x)`` of the chosen window, with
        ``end - start == size`` on each axis.

    Raises:
        AssertionError: if the region is smaller than the crop on either axis.
    """
    region_h = crop_region[1] - crop_region[0]
    region_w = crop_region[3] - crop_region[2]
    if region_h < crop_size[0] or region_w < crop_size[1]:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # AssertionError is kept so existing callers see the same exception type.
        raise AssertionError(
            f"crop size {crop_size} does not fit in region size {(region_h, region_w)}"
        )
    # randint is inclusive on both ends: every valid offset, including the one
    # that places the window flush against the region's far edge, can be drawn.
    # (The old ``range(start, end - size)`` silently excluded that last offset.)
    start_y = random.randint(crop_region[0], crop_region[1] - crop_size[0])
    start_x = random.randint(crop_region[2], crop_region[3] - crop_size[1])
    return start_y, start_y + crop_size[0], start_x, start_x + crop_size[1]
def _get_adaptive_crop_indices(crop_region, crop_size, num_candidate, dist_map, min_diff=False):
    """Draw random candidate crops and keep the one with the extreme ``dist_map`` mass.

    Each candidate window ``(start_y, end_y, start_x, end_x)`` is scored by
    ``torch.sum(dist_map[start_y:end_y, start_x:end_x])``.

    Args:
        crop_region: ``(start_y, end_y, start_x, end_x)`` area to sample in.
        crop_size: ``(size_y, size_x)`` of each candidate window.
        num_candidate: number of random candidates to compare; must be >= 1.
        dist_map: tensor sliced as ``[y, x]`` over its first two dimensions
            (presumably a per-pixel difference map -- TODO confirm with callers).
        min_diff: if True return the lowest-scoring candidate, otherwise the
            highest-scoring one.

    Returns:
        The chosen ``(start_y, end_y, start_x, end_x)`` tuple. On ties the
        earliest-drawn candidate wins.

    Raises:
        ValueError: if ``num_candidate`` is smaller than 1.
    """
    if num_candidate < 1:
        raise ValueError(f"num_candidate must be >= 1, got {num_candidate}")
    candidates = [_get_random_crop_indices(crop_region, crop_size) for _ in range(num_candidate)]
    with torch.no_grad():
        scores = [torch.sum(dist_map[sy:ey, sx:ex]).item() for sy, ey, sx, ex in candidates]
    # min()/max() return the first extreme element, matching the original
    # strict-comparison scan. Comparing real scores (instead of starting from
    # 0 / np.infty) also picks the true maximum when every score is <= 0, and
    # avoids ``np.infty``, which NumPy 2.0 removed.
    select = min if min_diff else max
    best = select(range(len(candidates)), key=scores.__getitem__)
    return candidates[best]
def get_split_list(divisor, dividend):
    """Split ``dividend`` into ``divisor`` parts using floor division.

    The first ``divisor - 1`` parts each equal ``dividend // divisor``. The
    final part takes whatever remains, so for ``divisor >= 1`` the parts sum
    to ``dividend``.

    Example: ``get_split_list(3, 10) == [3, 3, 4]``.
    """
    share = dividend // divisor
    leading = divisor - 1
    return [share] * leading + [dividend - share * leading]
def random_sampler(pic_size, crop_dict):
    """Draw square crops uniformly from the whole picture.

    Args:
        pic_size: ``(height, width)`` of the picture; index 0 bounds the
            y axis and index 1 bounds the x axis.
        crop_dict: maps a square side length (anything ``int()`` accepts) to
            the number of crops of that size to draw.

    Returns:
        dict keyed like ``crop_dict``; each value is a list of
        ``(start_y, end_y, start_x, end_x)`` tuples.
    """
    whole_picture = (0, pic_size[0], 0, pic_size[1])
    # Sampling over the full picture is region sampling with the picture
    # itself as the region.
    return region_sampler(whole_picture, crop_dict)
def region_sampler(crop_region, crop_dict):
    """Draw square crops uniformly from inside ``crop_region``.

    Args:
        crop_region: ``(start_y, end_y, start_x, end_x)`` area to sample in.
        crop_dict: maps a square side length (anything ``int()`` accepts) to
            the number of crops of that size to draw.

    Returns:
        dict keyed like ``crop_dict``; each value is a list of
        ``(start_y, end_y, start_x, end_x)`` tuples.
    """
    def _draw(side, count):
        square = (int(side), int(side))
        return [_get_random_crop_indices(crop_region, square) for _ in range(count)]

    return {side: _draw(side, count) for side, count in crop_dict.items()}
def adaptive_sampler(pic_size, crop_dict, num_candidate_dict, dist_map, min_diff=False):
    """Draw square crops from the whole picture, each chosen by ``dist_map`` score.

    Every crop is produced by ``_get_adaptive_crop_indices``, which draws
    several random candidates and keeps the highest-scoring one (or the
    lowest-scoring one when ``min_diff`` is True).

    Args:
        pic_size: ``(height, width)`` of the picture.
        crop_dict: maps a square side length (anything ``int()`` accepts) to
            the number of crops of that size to draw.
        num_candidate_dict: maps the same keys to the number of candidates
            compared per crop. It is only looked up for keys that draw at
            least one crop.
        dist_map: scoring map forwarded unchanged to ``_get_adaptive_crop_indices``.
        min_diff: prefer the lowest-scoring candidate instead of the highest.

    Returns:
        dict keyed like ``crop_dict``; each value is a list of
        ``(start_y, end_y, start_x, end_x)`` tuples.
    """
    whole_picture = (0, pic_size[0], 0, pic_size[1])
    result = {}
    for side, count in crop_dict.items():
        square = (int(side), int(side))
        # The candidate-count lookup stays inside the comprehension so that a
        # key with a zero crop count never touches num_candidate_dict.
        result[side] = [
            _get_adaptive_crop_indices(whole_picture, square, num_candidate_dict[side], dist_map, min_diff)
            for _ in range(count)
        ]
    return result
# TODO more flexible
def pyramid_sampler(pic_size, crop_dict):
    """Sample nested square crops, each smaller one drawn inside a larger one.

    Crop sizes (the keys of ``crop_dict``) are processed from largest to
    smallest.
    - Crops of the largest size are drawn uniformly from the whole picture.
    - The crops of each following size are distributed over the crops of the
      previous size with ``get_split_list``: every parent gets
      ``count // parents`` children and the last parent also takes the
      remainder. Each child is drawn inside its parent's window.

    Args:
        pic_size: ``(height, width)`` of the picture.
        crop_dict: maps a square side length (anything ``int()`` accepts) to
            the number of crops of that size to draw.

    Returns:
        dict with the same keys as ``crop_dict``, inserted from largest to
        smallest size; each value is a list of ``(start_y, end_y, start_x,
        end_x)`` tuples. An empty ``crop_dict`` yields ``{}``, matching the
        other samplers in this module.

    Note:
        ``get_split_list`` divides by the parent count, so a count of 0 for any
        size other than the smallest raises ZeroDivisionError.
    """
    if not crop_dict:
        # Previously this raised IndexError on ``sorted_key[0]``.
        return {}
    sizes = sorted(crop_dict, key=int, reverse=True)
    largest = sizes[0]
    whole_picture = (0, pic_size[0], 0, pic_size[1])
    largest_square = (int(largest), int(largest))
    crop_res_dict = {
        largest: [_get_random_crop_indices(whole_picture, largest_square) for _ in range(crop_dict[largest])]
    }
    for parent_key, child_key in zip(sizes, sizes[1:]):
        child_square = (int(child_key), int(child_key))
        split_list = get_split_list(crop_dict[parent_key], crop_dict[child_key])
        children = []
        for parent_region, count in zip(crop_res_dict[parent_key], split_list):
            children.extend(_get_random_crop_indices(parent_region, child_square) for _ in range(count))
        crop_res_dict[child_key] = children
    return crop_res_dict