diff --git a/codes/models/SwitchedResidualGenerator_arch.py b/codes/models/SwitchedResidualGenerator_arch.py deleted file mode 100644 index 433f3621..00000000 --- a/codes/models/SwitchedResidualGenerator_arch.py +++ /dev/null @@ -1,594 +0,0 @@ -import functools -import os -from collections import OrderedDict - -import torch -import torch.nn.functional as F -import torchvision -from torch import nn - -from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, ExpansionBlock2, MultiConvBlock -from models.switched_conv.switched_conv import BareConvSwitch, compute_attention_specificity, AttentionNorm -from models.switched_conv.switched_conv_util import save_attention_to_image_rgb -from trainer.networks import register_model -from utils.util import checkpoint - - -# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation -# Doubles the input filter count. -class HalvingProcessingBlock(nn.Module): - def __init__(self, filters): - super(HalvingProcessingBlock, self).__init__() - self.bnconv1 = ConvGnSilu(filters, filters * 2, stride=2, norm=False, bias=False) - self.bnconv2 = ConvGnSilu(filters * 2, filters * 2, norm=True, bias=False) - - def forward(self, x): - x = self.bnconv1(x) - return self.bnconv2(x) - - -# This is a classic u-net architecture with the goal of assigning each individual pixel an individual transform -# switching set. -class ConvBasisMultiplexer(nn.Module): - def __init__(self, input_channels, base_filters, reductions, processing_depth, multiplexer_channels, use_gn=True, use_exp2=False): - super(ConvBasisMultiplexer, self).__init__() - self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True) - self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(reductions)]) - reduction_filters = base_filters * 2 ** reductions - self.processing_blocks = nn.Sequential(OrderedDict([('block%i' % (i,), ConvGnSilu(reduction_filters, reduction_filters, bias=False)) for i in range(processing_depth)])) - if use_exp2: - self.expansion_blocks = nn.ModuleList([ExpansionBlock2(reduction_filters // (2 ** i)) for i in range(reductions)]) - else: - self.expansion_blocks = nn.ModuleList([ExpansionBlock(reduction_filters // (2 ** i)) for i in range(reductions)]) - - gap = base_filters - multiplexer_channels - cbl1_out = ((base_filters - (gap // 2)) // 4) * 4 # Must be multiples of 4 to use with group norm. - self.cbl1 = ConvGnSilu(base_filters, cbl1_out, norm=use_gn, bias=False, num_groups=4) - cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4 - self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, norm=use_gn, bias=False, num_groups=4) - self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, norm=False) - - def forward(self, x): - x = self.filter_conv(x) - reduction_identities = [] - for b in self.reduction_blocks: - reduction_identities.append(x) - x = b(x) - x = self.processing_blocks(x) - for i, b in enumerate(self.expansion_blocks): - x = b(x, reduction_identities[-i - 1]) - - x = self.cbl1(x) - x = self.cbl2(x) - x = self.cbl3(x) - return x - - -# torch.gather() which operates across 2d images. 
-def gather_2d(input, index): - b, c, h, w = input.shape - nodim = input.view(b, c, h * w) - ind_nd = index[:, 0]*w + index[:, 1] - ind_nd = ind_nd.unsqueeze(1) - ind_nd = ind_nd.repeat((1, c)) - ind_nd = ind_nd.unsqueeze(2) - result = torch.gather(nodim, dim=2, index=ind_nd) - result = result.squeeze() - if b == 1: - result = result.unsqueeze(0) - return result - - -class ConfigurableSwitchComputer(nn.Module): - def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm, - post_transform_block=None, - init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False, post_switch_conv=True, - anorm_multiplier=16): - super(ConfigurableSwitchComputer, self).__init__() - - tc = transform_count - self.multiplexer = multiplexer_net(tc) - - if pre_transform_block: - self.pre_transform = pre_transform_block() - else: - self.pre_transform = None - self.transforms = nn.ModuleList([transform_block() for _ in range(transform_count)]) - self.add_noise = add_scalable_noise_to_transforms - self.feed_transforms_into_multiplexer = feed_transforms_into_multiplexer - self.noise_scale = nn.Parameter(torch.full((1,), float(1e-3))) - - # And the switch itself, including learned scalars - self.switch = BareConvSwitch(initial_temperature=init_temp, attention_norm=AttentionNorm(transform_count, accumulator_size=anorm_multiplier * transform_count) if attention_norm else None) - self.switch_scale = nn.Parameter(torch.full((1,), float(1))) - self.post_transform_block = post_transform_block - if post_switch_conv: - self.post_switch_conv = ConvBnLelu(base_filters, base_filters, norm=False, bias=True) - # The post_switch_conv gets a low scale initially. The network can decide to magnify it (or not) - # depending on its needs. - self.psc_scale = nn.Parameter(torch.full((1,), float(.1))) - else: - self.post_switch_conv = None - self.update_norm = True - - def set_update_attention_norm(self, set_val): - self.update_norm = set_val - - # Regarding inputs: it is acceptable to pass in a tuple/list as an input for (x), but the first element - # *must* be the actual parameter that gets fed through the network - it is assumed to be the identity. - def forward(self, x, att_in=None, identity=None, output_attention_weights=True, fixed_scale=1, do_checkpointing=False, - output_att_logits=False): - if isinstance(x, tuple): - x1 = x[0] - else: - x1 = x - - if att_in is None: - att_in = x - - if identity is None: - identity = x1 - - if self.add_noise: - rand_feature = torch.randn_like(x1) * self.noise_scale - if isinstance(x, tuple): - x = (x1 + rand_feature,) + x[1:] - else: - x = x1 + rand_feature - - if not isinstance(x, tuple): - x = (x,) - if self.pre_transform: - x = self.pre_transform(*x) - if not isinstance(x, tuple): - x = (x,) - if do_checkpointing: - xformed = [checkpoint(t, *x) for t in self.transforms] - else: - xformed = [t(*x) for t in self.transforms] - - if not isinstance(att_in, tuple): - att_in = (att_in,) - if self.feed_transforms_into_multiplexer: - att_in = att_in + (torch.stack(xformed, dim=1),) - if do_checkpointing: - m = checkpoint(self.multiplexer, *att_in) - else: - m = self.multiplexer(*att_in) - - # It is assumed that [xformed] and [m] are collapsed into tensors at this point. 
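As context for the switch call that follows: it consumes the stacked transform outputs together with the multiplexer's per-transform logit maps. A minimal sketch of a temperature-controlled soft selection over those outputs is given below; this is an illustration of the idea only, not the actual BareConvSwitch implementation, and the (b, t, h, w) logit layout plus the softmax-over-transforms reduction are assumptions.

import torch
import torch.nn.functional as F

def soft_switch(xformed, m, temperature=1.0):
    # xformed: list of t transform outputs, each (b, f, h, w); m: (b, t, h, w) selection logits.
    stacked = torch.stack(xformed, dim=1)                  # (b, t, f, h, w)
    attention = F.softmax(m / temperature, dim=1)          # (b, t, h, w); sharpens as temperature -> 1 and below
    out = (stacked * attention.unsqueeze(2)).sum(dim=1)    # (b, f, h, w)
    return out, attention.permute(0, 2, 3, 1)              # attention reported as (b, h, w, t)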
- outputs, attention, att_logits = self.switch(xformed, m, True, self.update_norm, output_attention_logits=True) - if self.post_transform_block is not None: - outputs = self.post_transform_block(outputs) - - outputs = identity + outputs * self.switch_scale * fixed_scale - if self.post_switch_conv is not None: - outputs = outputs + self.post_switch_conv(outputs) * self.psc_scale * fixed_scale - if output_attention_weights: - if output_att_logits: - return outputs, attention, att_logits - else: - return outputs, attention - else: - return outputs - - def set_temperature(self, temp): - self.switch.set_attention_temperature(temp) - - -class ConfigurableSwitchedResidualGenerator2(nn.Module): - def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, - trans_layers, transformation_filters, attention_norm, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, - heightened_final_step=50000, upsample_factor=1, - add_scalable_noise_to_transforms=False): - super(ConfigurableSwitchedResidualGenerator2, self).__init__() - switches = [] - self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True) - self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True) - for _ in range(switch_depth): - multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts) - pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) - transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1) - switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=attention_norm, - transform_count=trans_counts, init_temp=initial_temp, - add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)) - - self.switches = nn.ModuleList(switches) - self.transformation_counts = trans_counts - self.init_temperature = initial_temp - self.final_temperature_step = final_temperature_step - self.heightened_temp_min = heightened_temp_min - self.heightened_final_step = heightened_final_step - self.attentions = None - self.upsample_factor = upsample_factor - assert self.upsample_factor == 2 or self.upsample_factor == 4 - - def forward(self, x): - # This is a common bug when evaluating SRG2 generators. It needs to be configured properly in eval mode. Just fail. 
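A second wrinkle with the guard below: `self.train` is nn.Module's train() method (a bound method, hence always truthy), so `not self.train` is always False and the eval-mode assertion never fires. The boolean flag is `self.training`, i.e. the intended check is:

    if not self.training:
        assert self.switches[0].switch.temperature == 1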
- if not self.train: - assert self.switches[0].switch.temperature == 1 - - x = self.initial_conv(x) - - self.attentions = [] - for i, sw in enumerate(self.switches): - x, att = sw.forward(x, True) - self.attentions.append(att) - - x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) - if self.upsample_factor > 2: - x = F.interpolate(x, scale_factor=2, mode="nearest") - x = self.upconv2(x) - x = self.final_conv(self.hr_conv(x)) - return x, x - - def set_temperature(self, temp): - [sw.set_temperature(temp) for sw in self.switches] - - def update_for_step(self, step, experiments_path='.'): - if self.attentions: - temp = max(1, - 1 + self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step) - if temp == 1 and self.heightened_final_step and step > self.final_temperature_step and \ - self.heightened_final_step != 1: - # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. - # without this, the attention specificity "spikes" incredibly fast in the last few iterations. - h_steps_total = self.heightened_final_step - self.final_temperature_step - h_steps_current = min(step - self.final_temperature_step, h_steps_total) - # The "gap" will represent the steps that need to be traveled as a linear function. - h_gap = 1 / self.heightened_temp_min - temp = h_gap * h_steps_current / h_steps_total - # Invert temperature to represent reality on this side of the curve - temp = 1 / temp - self.set_temperature(temp) - if step % 50 == 0: - output_path = os.path.join(experiments_path, "attention_maps", "a%i") - prefix = "attention_map_%i_%%i.png" % (step,) - [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] - - def get_debug_values(self, step): - temp = self.switches[0].switch.temperature - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - val = {"switch_temperature": temp} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - return val - - -# Computes a linear latent by performing processing on the reference image and returning the filters of a single point, -# which should be centered on the image patch being processed. -# -# Output is base_filters * 8. -class ReferenceImageBranch(nn.Module): - def __init__(self, base_filters=64): - super(ReferenceImageBranch, self).__init__() - self.features = nn.Sequential(ConvGnSilu(4, base_filters, kernel_size=7, bias=True), - HalvingProcessingBlock(base_filters), - ConvGnSilu(base_filters*2, base_filters*2, activation=True, norm=True, bias=False), - HalvingProcessingBlock(base_filters*2), - ConvGnSilu(base_filters*4, base_filters*4, activation=True, norm=True, bias=False), - HalvingProcessingBlock(base_filters*4), - ConvGnSilu(base_filters*8, base_filters*8, activation=True, norm=True, bias=False), - ConvGnSilu(base_filters*8, base_filters*8, activation=True, norm=True, bias=False)) - - # center_point is a [b,2] long tensor describing the center point of where the patch was taken from the reference - # image. - def forward(self, x, center_point): - x = self.features(x) - return gather_2d(x, center_point // 8) # Divide by 8 to scale the center_point down. - - -# Mutiplexer that combines a structured embedding with a contextual switch input to guide alterations to that input. 
-# -# Implemented as basically a u-net which reduces the input into the same structural space as the embedding, combines the -# two, then expands back into the original feature space. -class EmbeddingMultiplexer(nn.Module): - # Note: reductions=2 if the encoder is using interpolated input, otherwise reductions=3. - def __init__(self, nf, multiplexer_channels, reductions=2): - super(EmbeddingMultiplexer, self).__init__() - self.embedding_process = MultiConvBlock(256, 256, 256, kernel_size=3, depth=3, norm=True) - - self.filter_conv = ConvGnSilu(nf, nf, activation=True, norm=False, bias=True) - self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(nf * 2 ** i) for i in range(reductions)]) - reduction_filters = nf * 2 ** reductions - self.processing_blocks = nn.Sequential( - ConvGnSilu(reduction_filters + 256, reduction_filters + 256, kernel_size=1, activation=True, norm=False, bias=True), - ConvGnSilu(reduction_filters + 256, reduction_filters + 128, kernel_size=1, activation=True, norm=True, bias=False), - ConvGnSilu(reduction_filters + 128, reduction_filters, kernel_size=3, activation=True, norm=True, bias=False), - ConvGnSilu(reduction_filters, reduction_filters, kernel_size=3, activation=True, norm=True, bias=False)) - self.expansion_blocks = nn.ModuleList([ExpansionBlock2(reduction_filters // (2 ** i)) for i in range(reductions)]) - - gap = nf - multiplexer_channels - cbl1_out = ((nf - (gap // 2)) // 4) * 4 # Must be multiples of 4 to use with group norm. - self.cbl1 = ConvGnSilu(nf, cbl1_out, norm=True, bias=False, num_groups=4) - cbl2_out = ((nf - (3 * gap // 4)) // 4) * 4 - self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, norm=True, bias=False, num_groups=4) - self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, norm=False) - - def forward(self, x, embedding): - x = self.filter_conv(x) - embedding = self.embedding_process(embedding) - - reduction_identities = [] - for b in self.reduction_blocks: - reduction_identities.append(x) - x = b(x) - x = self.processing_blocks(torch.cat([x, embedding], dim=1)) - for i, b in enumerate(self.expansion_blocks): - x = b(x, reduction_identities[-i - 1]) - - x = self.cbl1(x) - x = self.cbl2(x) - x = self.cbl3(x) - return x - - -class QueryKeyMultiplexer(nn.Module): - def __init__(self, nf, multiplexer_channels, embedding_channels=256, reductions=2): - super(QueryKeyMultiplexer, self).__init__() - - # Blocks used to create the query - self.input_process = ConvGnSilu(nf, nf, activation=True, norm=False, bias=True) - self.embedding_process = ConvGnSilu(embedding_channels, 256, activation=True, norm=False, bias=True) - self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(nf * 2 ** i) for i in range(reductions)]) - reduction_filters = nf * 2 ** reductions - self.processing_blocks = nn.Sequential( - ConvGnSilu(reduction_filters + 256, reduction_filters + 256, kernel_size=1, activation=True, norm=False, bias=True), - ConvGnSilu(reduction_filters + 256, reduction_filters + 128, kernel_size=1, activation=True, norm=True, bias=False), - ConvGnSilu(reduction_filters + 128, reduction_filters, kernel_size=3, activation=True, norm=True, bias=False), - ConvGnSilu(reduction_filters, reduction_filters, kernel_size=3, activation=True, norm=True, bias=False)) - self.expansion_blocks = nn.ModuleList([ExpansionBlock2(reduction_filters // (2 ** i)) for i in range(reductions)]) - - # Blocks used to create the key - self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True) - - # Postprocessing blocks. 
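For orientation on how the key and postprocessing blocks assembled here are consumed in forward() below: the transform dimension is folded into the batch, the query is tiled once per transform, concatenated with the per-transform keys, reduced to a single channel, and reshaped back to (b, t, h, w). A shape-only sketch with illustrative sizes:

import torch
b, t, f, h, w = 2, 8, 64, 16, 16
q = torch.randn(b, f, h, w)                                  # query from the u-net branch
transformations = torch.randn(b, t, f, h, w)                 # stacked transform outputs
k = transformations.view(b * t, f, h, w)                     # one key map per transform
q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w)
qk = torch.cat([q, k], dim=1)                                # (b*t, 2f, h, w) -> 1x1 conv stack -> (b*t, 1, h, w)
# a final .view(b, t, h, w) recovers a per-transform logit map for the switch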
- self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=1, activation=True, norm=False, bias=False) - self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4) - self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False) - - def forward(self, x, embedding, transformations): - q = self.input_process(x) - embedding = self.embedding_process(embedding) - reduction_identities = [] - for b in self.reduction_blocks: - reduction_identities.append(q) - q = b(q) - q = self.processing_blocks(torch.cat([q, embedding], dim=1)) - for i, b in enumerate(self.expansion_blocks): - q = b(q, reduction_identities[-i - 1]) - - b, t, f, h, w = transformations.shape - k = transformations.view(b * t, f, h, w) - k = self.key_process(k) - - q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w) - v = self.query_key_combine(torch.cat([q, k], dim=1)) - - v = self.cbl1(v) - v = self.cbl2(v) - - return v.view(b, t, h, w) - - -class QueryKeyPyramidMultiplexer(nn.Module): - def __init__(self, nf, multiplexer_channels, reductions=3): - super(QueryKeyPyramidMultiplexer, self).__init__() - - # Blocks used to create the query - self.input_process = ConvGnSilu(nf, nf, activation=True, norm=False, bias=True) - self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(nf * 2 ** i) for i in range(reductions)]) - reduction_filters = nf * 2 ** reductions - self.processing_blocks = nn.Sequential(OrderedDict([('block%i' % (i,), ConvGnSilu(reduction_filters, reduction_filters, kernel_size=1, norm=True, bias=False)) for i in range(3)])) - self.expansion_blocks = nn.ModuleList([ExpansionBlock2(reduction_filters // (2 ** i)) for i in range(reductions)]) - - # Blocks used to create the key - self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True) - - # Postprocessing blocks. - self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=3, activation=True, norm=False, bias=False) - self.cbl0 = ConvGnSilu(nf, nf, kernel_size=3, activation=True, norm=True, bias=False) - self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4) - self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False) - - def forward(self, x, transformations): - q = self.input_process(x) - reduction_identities = [] - for b in self.reduction_blocks: - reduction_identities.append(q) - q = b(q) - q = self.processing_blocks(q) - for i, b in enumerate(self.expansion_blocks): - q = b(q, reduction_identities[-i - 1]) - - b, t, f, h, w = transformations.shape - k = transformations.view(b * t, f, h, w) - k = self.key_process(k) - - q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w) - v = self.query_key_combine(torch.cat([q, k], dim=1)) - v = self.cbl0(v) - v = self.cbl1(v) - v = self.cbl2(v) - - return v.view(b, t, h, w) - - -# Base class for models that utilize ConfigurableSwitchComputer. Provides basis functionality like logging -# switch temperature, distribution and images, as well as managing attention norms. -class SwitchModelBase(nn.Module): - def __init__(self, init_temperature=10, final_temperature_step=10000): - super(SwitchModelBase, self).__init__() - self.switches = [] # The implementing class is expected to set this to a list of all ConfigurableSwitchComputers. - self.attentions = [] # The implementing class is expected to set this in forward() to the output of the attention blocks. - self.lr = None # The implementing class is expected to set this to the input image fed into the generator. 
If not - # set, the attention logger will not output an image reference. - self.init_temperature = init_temperature - self.final_temperature_step = final_temperature_step - - def set_temperature(self, temp): - [sw.set_temperature(temp) for sw in self.switches] - - def update_for_step(self, step, experiments_path='.'): - # All-reduce the attention norm. - for sw in self.switches: - sw.switch.reduce_norm_params() - - temp = max(1, 1 + self.init_temperature * - (self.final_temperature_step - step) / self.final_temperature_step) - self.set_temperature(temp) - if step % 100 == 0: - output_path = os.path.join(experiments_path, "attention_maps") - prefix = "amap_%i_a%i_%%i.png" - [save_attention_to_image_rgb(output_path, self.attentions[i], self.attentions[i].shape[3], prefix % (step, i), step, - output_mag=False) for i in range(len(self.attentions))] - if self.lr is not None: - torchvision.utils.save_image(self.lr[:, :3], os.path.join(experiments_path, "attention_maps", - "amap_%i_base_image.png" % (step,))) - - # This is a bit awkward. We want this plot to show up in TB as a histogram, but we are getting an intensity - # plot out of the attention norm tensor. So we need to convert it back into a list of indexes, then feed into TB. - def compute_anorm_histogram(self): - intensities = [sw.switch.attention_norm.compute_buffer_norm().clone().detach().cpu() for sw in self.switches] - result = [] - for intensity in intensities: - intensity = intensity * 10 - bins = torch.tensor(list(range(len(intensity)))) - intensity = intensity.long() - result.append(bins.repeat_interleave(intensity, 0)) - return result - - def get_debug_values(self, step, net_name): - temp = self.switches[0].switch.temperature - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - anorms = self.compute_anorm_histogram() - val = {"switch_temperature": temp} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - val["switch_%i_attention_norm_histogram" % (i,)] = anorms[i] - return val - - -class ConfigurableSwitchedResidualGenerator2(nn.Module): - def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, - trans_layers, transformation_filters, attention_norm, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, - heightened_final_step=50000, upsample_factor=1, - add_scalable_noise_to_transforms=False): - super(ConfigurableSwitchedResidualGenerator2, self).__init__() - switches = [] - self.initial_conv = ConvBnLelu(3, transformation_filters, kernel_size=7, norm=False, activation=False, bias=True) - self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True) - for _ in range(switch_depth): - multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts) - pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) - transform_fn = 
functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1) - switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - attention_norm=attention_norm, - transform_count=trans_counts, init_temp=initial_temp, - add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)) - - self.switches = nn.ModuleList(switches) - self.transformation_counts = trans_counts - self.init_temperature = initial_temp - self.final_temperature_step = final_temperature_step - self.heightened_temp_min = heightened_temp_min - self.heightened_final_step = heightened_final_step - self.attentions = None - self.upsample_factor = upsample_factor - self.lr = None - assert self.upsample_factor == 2 or self.upsample_factor == 4 - - def forward(self, x): - self.lr = x.detach().cpu() - - # This is a common bug when evaluating SRG2 generators. It needs to be configured properly in eval mode. Just fail. - if not self.train: - assert self.switches[0].switch.temperature == 1 - - x = self.initial_conv(x) - - self.attentions = [] - for i, sw in enumerate(self.switches): - x, att = checkpoint(sw, x) - self.attentions.append(att) - - x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) - if self.upsample_factor > 2: - x = F.interpolate(x, scale_factor=2, mode="nearest") - x = self.upconv2(x) - x = self.final_conv(self.hr_conv(x)) - return x - - def set_temperature(self, temp): - [sw.set_temperature(temp) for sw in self.switches] - - def update_for_step(self, step, experiments_path='.'): - if self.attentions: - temp = max(1, - 1 + self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step) - if temp == 1 and self.heightened_final_step and step > self.final_temperature_step and \ - self.heightened_final_step != 1: - # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. - # without this, the attention specificity "spikes" incredibly fast in the last few iterations. - h_steps_total = self.heightened_final_step - self.final_temperature_step - h_steps_current = min(step - self.final_temperature_step, h_steps_total) - # The "gap" will represent the steps that need to be traveled as a linear function. 
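To make this schedule concrete, the linear phase computed above can be written as a stand-alone helper with worked values (numbers are illustrative and mirror the constructor defaults initial_temp=20, final_temperature_step=50000):

def linear_temperature(step, init_temp=20, final_step=50000):
    return max(1, 1 + init_temp * (final_step - step) / final_step)

# linear_temperature(0) == 21.0, linear_temperature(25000) == 11.0, linear_temperature(50000) == 1

Past final_temperature_step, when heightened_final_step is configured, the code below switches to the reciprocal curve temp = 1 / (h_gap * progress) with h_gap = 1 / heightened_temp_min, which lands exactly on heightened_temp_min once step reaches heightened_final_step.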
- h_gap = 1 / self.heightened_temp_min - temp = h_gap * h_steps_current / h_steps_total - # Invert temperature to represent reality on this side of the curve - temp = 1 / temp - self.set_temperature(temp) - if step % 100 == 0: - output_path = os.path.join(experiments_path, "attention_maps") - prefix = "amap_%i_a%i_%%i.png" - [save_attention_to_image_rgb(output_path, self.attentions[i], self.attentions[i].shape[3], prefix % (step, i), step, - output_mag=False) for i in range(len(self.attentions))] - if self.lr is not None: - torchvision.utils.save_image(self.lr[:, :3], os.path.join(experiments_path, "attention_maps", - "amap_%i_base_image.png" % (step,))) - - def get_debug_values(self, step, net_name): - temp = self.switches[0].switch.temperature - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - val = {"switch_temperature": temp} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - return val - -@register_model -def register_ConfigurableSwitchedResidualGenerator2(opt_net, opt): - return ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], - switch_filters=opt_net['switch_filters'], - switch_reductions=opt_net['switch_reductions'], - switch_processing_layers=opt_net[ - 'switch_processing_layers'], - trans_counts=opt_net['trans_counts'], - trans_kernel_sizes=opt_net['trans_kernel_sizes'], - trans_layers=opt_net['trans_layers'], - transformation_filters=opt_net['transformation_filters'], - attention_norm=opt_net['attention_norm'], - initial_temp=opt_net['temperature'], - final_temperature_step=opt_net['temperature_final_step'], - heightened_temp_min=opt_net['heightened_temp_min'], - heightened_final_step=opt_net['heightened_final_step'], - upsample_factor=scale, - add_scalable_noise_to_transforms=opt_net['add_noise'], - for_video=opt_net['for_video']) \ No newline at end of file diff --git a/codes/models/srg2_classic.py b/codes/models/srg2_classic.py deleted file mode 100644 index b0245023..00000000 --- a/codes/models/srg2_classic.py +++ /dev/null @@ -1,223 +0,0 @@ -import os - -import torch -import torchvision -from matplotlib import cm -from torch import nn -import torch.nn.functional as F -import functools -from collections import OrderedDict - -from models.SwitchedResidualGenerator_arch import HalvingProcessingBlock, ConfigurableSwitchComputer -from models.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock, MultiConvBlock -from models.switched_conv.switched_conv import BareConvSwitch, AttentionNorm -from trainer.networks import register_model -from utils.util import checkpoint - - -# This is a classic u-net architecture with the goal of assigning each individual pixel an individual transform -# switching set. 
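Before the classic copy of ConvBasisMultiplexer below, a worked example of its channel bookkeeping (the same arithmetic used in the deleted multiplexers above; sizes are illustrative): with base_filters=64 and multiplexer_channels=8 the collapse runs 64 -> 36 -> 20 -> 8 channels, each intermediate width rounded down to a multiple of 4 so it stays divisible by the num_groups=4 group norm.

base_filters, multiplexer_channels = 64, 8
gap = base_filters - multiplexer_channels                 # 56
cbl1_out = ((base_filters - (gap // 2)) // 4) * 4         # ((64 - 28) // 4) * 4 = 36
cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4     # ((64 - 42) // 4) * 4 = 20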
-class ConvBasisMultiplexer(nn.Module): - def __init__(self, input_channels, base_filters, reductions, processing_depth, multiplexer_channels, use_gn=True): - super(ConvBasisMultiplexer, self).__init__() - self.filter_conv = ConvGnSilu(input_channels, base_filters, bias=True) - self.reduction_blocks = nn.ModuleList([HalvingProcessingBlock(base_filters * 2 ** i) for i in range(reductions)]) - reduction_filters = base_filters * 2 ** reductions - self.processing_blocks = nn.Sequential(OrderedDict([('block%i' % (i,), ConvGnSilu(reduction_filters, reduction_filters, bias=False)) for i in range(processing_depth)])) - self.expansion_blocks = nn.ModuleList([ExpansionBlock(reduction_filters // (2 ** i)) for i in range(reductions)]) - - gap = base_filters - multiplexer_channels - cbl1_out = ((base_filters - (gap // 2)) // 4) * 4 # Must be multiples of 4 to use with group norm. - self.cbl1 = ConvGnSilu(base_filters, cbl1_out, norm=use_gn, bias=False, num_groups=4) - cbl2_out = ((base_filters - (3 * gap // 4)) // 4) * 4 - self.cbl2 = ConvGnSilu(cbl1_out, cbl2_out, norm=use_gn, bias=False, num_groups=4) - self.cbl3 = ConvGnSilu(cbl2_out, multiplexer_channels, bias=True, norm=False) - - def forward(self, x): - x = self.filter_conv(x) - reduction_identities = [] - for b in self.reduction_blocks: - reduction_identities.append(x) - x = b(x) - x = self.processing_blocks(x) - for i, b in enumerate(self.expansion_blocks): - x = b(x, reduction_identities[-i - 1]) - - x = self.cbl1(x) - x = self.cbl2(x) - x = self.cbl3(x) - return x - - -def compute_attention_specificity(att_weights, topk=3): - att = att_weights.detach() - vals, indices = torch.topk(att, topk, dim=-1) - avg = torch.sum(vals, dim=-1) - avg = avg.flatten().mean() - return avg.item(), indices.flatten().detach() - - -# Copied from torchvision.utils.save_image. Allows specifying pixel format. -def save_image(tensor, fp, nrow=8, padding=2, - normalize=False, range=None, scale_each=False, pad_value=0, format=None, pix_format=None): - from PIL import Image - grid = torchvision.utils.make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, - normalize=normalize, range=range, scale_each=scale_each) - # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer - ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() - im = Image.fromarray(ndarr, mode=pix_format).convert('RGB') - im.save(fp, format=format) - - -def save_attention_to_image(folder, attention_out, attention_size, step, fname_part="map", l_mult=1.0): - magnitude, indices = torch.topk(attention_out, 1, dim=-1) - magnitude = magnitude.squeeze(3) - indices = indices.squeeze(3) - # indices is an integer tensor (b,w,h) where values are on the range [0,attention_size] - # magnitude is a float tensor (b,w,h) [0,1] representing the magnitude of that attention. - # Use HSV colorspace to show this. Hue is mapped to the indices, Lightness is mapped to intensity, - # Saturation is left fixed. 
- hue = indices.float() / attention_size - saturation = torch.full_like(hue, .8) - value = magnitude * l_mult - hsv_img = torch.stack([hue, saturation, value], dim=1) - - output_path=os.path.join(folder, "attention_maps", fname_part) - os.makedirs(output_path, exist_ok=True) - save_image(hsv_img, os.path.join(output_path, "attention_map_%i.png" % (step,)), pix_format="HSV") - - -def save_attention_to_image_rgb(output_folder, attention_out, attention_size, file_prefix, step, cmap_discrete_name='viridis'): - magnitude, indices = torch.topk(attention_out, 3, dim=-1) - magnitude = magnitude.cpu() - indices = indices.cpu() - magnitude /= torch.max(torch.abs(torch.min(magnitude)), torch.abs(torch.max(magnitude))) - colormap = cm.get_cmap(cmap_discrete_name, attention_size) - colormap_mag = cm.get_cmap(cmap_discrete_name) - os.makedirs(os.path.join(output_folder), exist_ok=True) - for i in range(3): - img = torch.tensor(colormap(indices[:,:,:,i].detach().numpy())) - img = img.permute((0, 3, 1, 2)) - save_image(img, os.path.join(output_folder, file_prefix + "_%i_%s.png" % (step, "rgb_%i" % (i,))), pix_format="RGBA") - - mag_image = torch.tensor(colormap_mag(magnitude[:,:,:,i].detach().numpy())) - mag_image = mag_image.permute((0, 3, 1, 2)) - save_image(mag_image, os.path.join(output_folder, file_prefix + "_%i_%s.png" % (step, "mag_%i" % (i,))), pix_format="RGBA") - - -class ConfigurableSwitchedResidualGenerator2(nn.Module): - def __init__(self, switch_depth, switch_filters, switch_reductions, switch_processing_layers, trans_counts, trans_kernel_sizes, - trans_layers, transformation_filters, initial_temp=20, final_temperature_step=50000, heightened_temp_min=1, - heightened_final_step=50000, upsample_factor=1, - add_scalable_noise_to_transforms=False, for_video=False): - super(ConfigurableSwitchedResidualGenerator2, self).__init__() - switches = [] - self.for_video = for_video - if for_video: - self.initial_conv = ConvBnLelu(6, transformation_filters, stride=upsample_factor, norm=False, activation=False, bias=True) - else: - self.initial_conv = ConvBnLelu(3, transformation_filters, norm=False, activation=False, bias=True) - self.upconv1 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.upconv2 = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.hr_conv = ConvBnLelu(transformation_filters, transformation_filters, norm=False, bias=True) - self.final_conv = ConvBnLelu(transformation_filters, 3, norm=False, activation=False, bias=True) - for _ in range(switch_depth): - multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, switch_processing_layers, trans_counts) - pretransform_fn = functools.partial(ConvBnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) - transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), transformation_filters, kernel_size=trans_kernel_sizes, depth=trans_layers, weight_init_factor=.1) - switches.append(ConfigurableSwitchComputer(transformation_filters, multiplx_fn, attention_norm=True, - pre_transform_block=pretransform_fn, transform_block=transform_fn, - transform_count=trans_counts, init_temp=initial_temp, - add_scalable_noise_to_transforms=add_scalable_noise_to_transforms)) - - self.switches = nn.ModuleList(switches) - self.transformation_counts = trans_counts - self.init_temperature = initial_temp - self.final_temperature_step = 
final_temperature_step - self.heightened_temp_min = heightened_temp_min - self.heightened_final_step = heightened_final_step - self.attentions = None - self.upsample_factor = upsample_factor - assert self.upsample_factor == 2 or self.upsample_factor == 4 - - def forward(self, x, ref=None): - if self.for_video: - x_lg = F.interpolate(x, scale_factor=self.upsample_factor, mode="bicubic") - if ref is None: - ref = torch.zeros_like(x_lg) - x_lg = torch.cat([x_lg, ref], dim=1) - else: - x_lg = x - x = self.initial_conv(x_lg) - - self.attentions = [] - for i, sw in enumerate(self.switches): - x, att = checkpoint(sw, x) - self.attentions.append(att) - - x = self.upconv1(F.interpolate(x, scale_factor=2, mode="nearest")) - if self.upsample_factor > 2: - x = F.interpolate(x, scale_factor=2, mode="nearest") - x = self.upconv2(x) - x = self.final_conv(self.hr_conv(x)) - return x - - def set_temperature(self, temp): - [sw.set_temperature(temp) for sw in self.switches] - - def update_for_step(self, step, experiments_path='.'): - if self.attentions: - temp = max(1, - 1 + self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step) - if temp == 1 and self.heightened_final_step and step > self.final_temperature_step and \ - self.heightened_final_step != 1: - # Once the temperature passes (1) it enters an inverted curve to match the linear curve from above. - # without this, the attention specificity "spikes" incredibly fast in the last few iterations. - h_steps_total = self.heightened_final_step - self.final_temperature_step - h_steps_current = min(step - self.final_temperature_step, h_steps_total) - # The "gap" will represent the steps that need to be traveled as a linear function. - h_gap = 1 / self.heightened_temp_min - temp = h_gap * h_steps_current / h_steps_total - # Invert temperature to represent reality on this side of the curve - temp = 1 / temp - self.set_temperature(temp) - if step % 50 == 0: - [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))] - - def get_debug_values(self, step, net_name): - temp = self.switches[0].switch.temperature - mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] - means = [i[0] for i in mean_hists] - hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] - val = {"switch_temperature": temp} - for i in range(len(means)): - val["switch_%i_specificity" % (i,)] = means[i] - val["switch_%i_histogram" % (i,)] = hists[i] - return val - - -class Interpolate(nn.Module): - def __init__(self, factor, mode="nearest"): - super(Interpolate, self).__init__() - self.factor = factor - self.mode = mode - - def forward(self, x): - return F.interpolate(x, scale_factor=self.factor, mode=self.mode) - -@register_model -def register_srg2classic(opt_net, opt): - return ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], - switch_filters=opt_net['switch_filters'], - switch_reductions=opt_net['switch_reductions'], - switch_processing_layers=opt_net['switch_processing_layers'], - trans_counts=opt_net['trans_counts'], - trans_kernel_sizes=opt_net['trans_kernel_sizes'], - trans_layers=opt_net['trans_layers'], - transformation_filters=opt_net['transformation_filters'], - initial_temp=opt_net['temperature'], - final_temperature_step=opt_net['temperature_final_step'], - heightened_temp_min=opt_net['heightened_temp_min'], - heightened_final_step=opt_net['heightened_final_step'], - 
upsample_factor=scale, - add_scalable_noise_to_transforms=opt_net['add_noise']) \ No newline at end of file diff --git a/codes/models/switched_conv b/codes/models/switched_conv deleted file mode 160000 index cb520afd..00000000 --- a/codes/models/switched_conv +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cb520afd4da97796bfca398feeef18a7bd18475c diff --git a/codes/models/vqvae/kmeans_mask_producer.py b/codes/models/vqvae/kmeans_mask_producer.py new file mode 100644 index 00000000..33d687d7 --- /dev/null +++ b/codes/models/vqvae/kmeans_mask_producer.py @@ -0,0 +1,48 @@ +import torch +from torch import nn +import torch.nn.functional as F +from torchvision.models.resnet import Bottleneck + +from models.pixel_level_contrastive_learning.resnet_unet import UResNet50 +from trainer.networks import register_model +from utils.kmeans import kmeans_predict +from utils.util import opt_get + + +class UResnetMaskProducer(nn.Module): + def __init__(self, pretrained_uresnet_path, kmeans_centroid_path, mask_scales=[.125,.25,.5,1]): + super().__init__() + _, centroids = torch.load(kmeans_centroid_path) + self.centroids = nn.Parameter(centroids) + self.ures = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda') + self.mask_scales = mask_scales + + sd = torch.load(pretrained_uresnet_path) + # An assumption is made that the state_dict came from a byol model. Strip out unnecessary weights.. + resnet_sd = {} + for k, v in sd.items(): + if 'target_encoder.net.' in k: + resnet_sd[k.replace('target_encoder.net.', '')] = v + + self.ures.load_state_dict(resnet_sd, strict=True) + self.ures.eval() + + def forward(self, x): + with torch.no_grad(): + latents = self.ures(x) + b,c,h,w = latents.shape + latents = latents.permute(0,2,3,1).reshape(b*h*w,c) + masks = kmeans_predict(latents, self.centroids).float() + masks = masks.reshape(b,1,h,w) + interpolated_masks = {} + for sf in self.mask_scales: + dim_h, dim_w = int(sf*x.shape[-2]), int(sf*x.shape[-1]) + imask = F.interpolate(masks, size=(dim_h,dim_w), mode="nearest") + interpolated_masks[dim_w] = imask.long() + return interpolated_masks + + +@register_model +def register_uresnet_mask_producer(opt_net, opt): + kw = opt_get(opt_net, ['kwargs'], {}) + return UResnetMaskProducer(**kw) diff --git a/codes/models/vqvae/scaled_weight_conv.py b/codes/models/vqvae/scaled_weight_conv.py new file mode 100644 index 00000000..a6c212da --- /dev/null +++ b/codes/models/vqvae/scaled_weight_conv.py @@ -0,0 +1,130 @@ +from typing import Optional, List + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd +from torch.nn.modules.utils import _ntuple +import torch.nn.functional as F + + +_pair = _ntuple(2) + + +# Indexes the
index of input=b,c,h,w,p by the long tensor index=b,1,h,w. Result is b,c,h,w. +# Frankly - IMO - this is what torch.gather should do. +def index_2d(input, index): + index = index.repeat(1,input.shape[1],1,1) + e = torch.eye(input.shape[-1], device=input.device) + result = e[index] * input + return result.sum(-1) + + +# Drop-in implementation of Conv2d that can apply masked scales&shifts to the convolution weights. +class ScaledWeightConv(_ConvNd): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, + stride = 1, + padding = 0, + dilation = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + breadth: int = 8, + ): + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + super().__init__( + in_channels, out_channels, _pair(kernel_size), stride, padding, dilation, + False, _pair(0), groups, bias, padding_mode) + + self.weight_scales = nn.ParameterList([nn.Parameter(torch.ones(out_channels, in_channels, kernel_size, kernel_size)) for _ in range(breadth)]) + self.shifts = nn.ParameterList([nn.Parameter(torch.zeros(out_channels, in_channels, kernel_size, kernel_size)) for _ in range(breadth)]) + for w, s in zip(self.weight_scales, self.shifts): + w.FOR_SCALE_SHIFT = True + s.FOR_SCALE_SHIFT = True + # This should probably be configurable at some point. + for p in self.parameters(): + if not hasattr(p, "FOR_SCALE_SHIFT"): + p.DO_NOT_TRAIN = True + + def _weighted_conv_forward(self, input, weight): + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, self.bias, self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + def forward(self, input: Tensor, masks: dict) -> Tensor: + # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any + # good at all, this can be made more efficient by performing a single conv pass with multiple masks. + weighted_convs = [self._weighted_conv_forward(input, self.weight * scale + shift) for scale, shift in zip(self.weight_scales, self.shifts)] + weighted_convs = torch.stack(weighted_convs, dim=-1) + + needed_mask = weighted_convs.shape[-2] + assert needed_mask in masks.keys() + + return index_2d(weighted_convs, masks[needed_mask]) + + +# Drop-in implementation of ConvTranspose2d that can apply masked scales&shifts to the convolution weights. +class ScaledWeightConvTranspose(_ConvTransposeNd): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, + stride = 1, + padding = 0, + output_padding = 0, + groups: int = 1, + bias: bool = True, + dilation: int = 1, + padding_mode: str = 'zeros', + breadth: int = 8, + ): + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + super().__init__( + in_channels, out_channels, _pair(kernel_size), stride, padding, dilation, + True, output_padding, groups, bias, padding_mode) + + self.weight_scales = nn.ParameterList([nn.Parameter(torch.ones(in_channels, out_channels, kernel_size, kernel_size)) for _ in range(breadth)]) + self.shifts = nn.ParameterList([nn.Parameter(torch.zeros(in_channels, out_channels, kernel_size, kernel_size)) for _ in range(breadth)]) + for w, s in zip(self.weight_scales, self.shifts): + w.FOR_SCALE_SHIFT = True + s.FOR_SCALE_SHIFT = True + # This should probably be configurable at some point. 
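Regarding index_2d above (and the remark that this is what torch.gather "should" do): the one-hot/torch.eye reduction it performs is equivalent to a gather along the trailing candidate dimension, which avoids materializing the (b, c, h, w, p) one-hot tensor. A sketch of that form, assuming the mask values lie in [0, p):

import torch

def index_2d_gather(input, index):
    # input: (b, c, h, w, p); index: (b, 1, h, w) long tensor of candidate ids.
    idx = index.unsqueeze(-1).expand(-1, input.shape[1], -1, -1, 1)   # (b, c, h, w, 1)
    return torch.gather(input, dim=-1, index=idx).squeeze(-1)         # (b, c, h, w)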
+ for nm, p in self.named_parameters(): + if nm == 'weight': + p.DO_NOT_TRAIN = True + + def _conv_transpose_forward(self, input, weight, output_size) -> Tensor: + if self.padding_mode != 'zeros': + raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d') + + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) + + return F.conv_transpose2d( + input, weight, self.bias, self.stride, self.padding, + output_padding, self.groups, self.dilation) + + def forward(self, input: Tensor, masks: dict, output_size: Optional[List[int]] = None) -> Tensor: + # This is an exceptionally inefficient way of achieving this functionality. The hope is that if this is any + # good at all, this can be made more efficient by performing a single conv pass with multiple masks. + weighted_convs = [self._conv_transpose_forward(input, self.weight * scale + shift, output_size) + for scale, shift in zip(self.weight_scales, self.shifts)] + weighted_convs = torch.stack(weighted_convs, dim=-1) + + needed_mask = weighted_convs.shape[-2] + assert needed_mask in masks.keys() + + return index_2d(weighted_convs, masks[needed_mask]) diff --git a/codes/models/vqvae/weighted_conv_vqvae.py b/codes/models/vqvae/weighted_conv_vqvae.py new file mode 100644 index 00000000..9452ed26 --- /dev/null +++ b/codes/models/vqvae/weighted_conv_vqvae.py @@ -0,0 +1,267 @@ +# Copyright 2018 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import torch +from torch import nn +from torch.nn import functional as F + +import torch.distributed as distributed + +from models.vqvae.scaled_weight_conv import ScaledWeightConv, ScaledWeightConvTranspose +from trainer.networks import register_model +from utils.util import checkpoint, opt_get + + +class Quantize(nn.Module): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + + embed = torch.randn(dim, n_embed) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, input): + flatten = input.reshape(-1, self.dim) + dist = ( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True) + ) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = self.embed_code(embed_ind) + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + if distributed.is_initialized() and distributed.get_world_size() > 1: + distributed.all_reduce(embed_onehot_sum) + distributed.all_reduce(embed_sum) + + self.cluster_size.data.mul_(self.decay).add_( + embed_onehot_sum, alpha=1 - self.decay + ) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = ( + (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + diff = (quantize.detach() - input).pow(2).mean() + quantize = input + (quantize - input).detach() + + return quantize, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, channel, breadth): + super().__init__() + + self.conv = nn.ModuleList([ + nn.ReLU(inplace=True), + ScaledWeightConv(in_channel, channel, 3, padding=1, breadth=breadth), + nn.ReLU(inplace=True), + ScaledWeightConv(channel, in_channel, 1, breadth=breadth), + ]) + + def forward(self, input, masks): + out = input + for m in self.conv: + if isinstance(m, ScaledWeightConv): + out = m(out, masks) + else: + out = m(out) + out += input + + return out + + +class Encoder(nn.Module): + def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride, breadth): + super().__init__() + + if stride == 4: + blocks = [ + ScaledWeightConv(in_channel, channel // 2, 4, stride=2, padding=1, breadth=breadth), + nn.ReLU(inplace=True), + ScaledWeightConv(channel // 2, channel, 4, stride=2, padding=1, breadth=breadth), + nn.ReLU(inplace=True), + ScaledWeightConv(channel, channel, 3, padding=1, breadth=breadth), + ] + + elif stride == 2: + blocks = [ + ScaledWeightConv(in_channel, channel // 2, 4, stride=2, padding=1, breadth=breadth), + nn.ReLU(inplace=True), + ScaledWeightConv(channel // 2, channel, 3, padding=1, breadth=breadth), + ] + + for i in range(n_res_block): + blocks.append(ResBlock(channel, n_res_channel, breadth=breadth)) + + blocks.append(nn.ReLU(inplace=True)) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, input): + for block in self.blocks: + if isinstance(block, ScaledWeightConv) or isinstance(block, 
ResBlock): + input = block(input, self.masks) + else: + input = block(input) + return input + + +class Decoder(nn.Module): + def __init__( + self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride, breadth + ): + super().__init__() + + blocks = [ScaledWeightConv(in_channel, channel, 3, padding=1, breadth=breadth)] + + for i in range(n_res_block): + blocks.append(ResBlock(channel, n_res_channel, breadth=breadth)) + + blocks.append(nn.ReLU(inplace=True)) + + if stride == 4: + blocks.extend( + [ + ScaledWeightConvTranspose(channel, channel // 2, 4, stride=2, padding=1, breadth=breadth), + nn.ReLU(inplace=True), + ScaledWeightConvTranspose( + channel // 2, out_channel, 4, stride=2, padding=1, breadth=breadth + ), + ] + ) + + elif stride == 2: + blocks.append( + ScaledWeightConvTranspose(channel, out_channel, 4, stride=2, padding=1, breadth=breadth) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, input): + for block in self.blocks: + if isinstance(block, ScaledWeightConvTranspose) or isinstance(block, ResBlock) \ + or isinstance(block, ScaledWeightConv): + input = block(input, self.masks) + else: + input = block(input) + return input + + +class VQVAE(nn.Module): + def __init__( + self, + in_channel=3, + channel=128, + n_res_block=2, + n_res_channel=32, + codebook_dim=64, + codebook_size=512, + breadth=8, + decay=0.99, + ): + super().__init__() + + self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth) + self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth) + self.quantize_conv_t = ScaledWeightConv(channel, codebook_dim, 1, breadth=breadth) + self.quantize_t = Quantize(codebook_dim, codebook_size) + self.dec_t = Decoder( + codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2, breadth=breadth + ) + self.quantize_conv_b = ScaledWeightConv(codebook_dim + channel, codebook_dim, 1, breadth=breadth) + self.quantize_b = Quantize(codebook_dim, codebook_size) + self.upsample_t = ScaledWeightConvTranspose( + codebook_dim, codebook_dim, 4, stride=2, padding=1, breadth=breadth + ) + self.dec = Decoder( + codebook_dim + codebook_dim, + in_channel, + channel, + n_res_block, + n_res_channel, + stride=4, + breadth=breadth + ) + + def forward(self, input, masks): + # This awkward injection point is necessary to enable checkpointing to work. 
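To elaborate on the comment above: checkpoint() is invoked with only the activation tensor (e.g. checkpoint(self.enc_b, input)), so the mask dict cannot be passed as an argument to the checkpointed forward pass; assigning it to each sub-module beforehand makes it available as self.masks when that forward is run. The assignment loop follows.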
+ for m in [self.enc_b, self.enc_t, self.dec_t, self.dec]: + m.masks = masks + + quant_t, quant_b, diff, _, _ = self.encode(input, masks) + dec = self.decode(quant_t, quant_b, masks) + + return dec, diff + + def encode(self, input, masks): + enc_b = checkpoint(self.enc_b, input) + enc_t = checkpoint(self.enc_t, enc_b) + + quant_t = self.quantize_conv_t(enc_t, masks).permute(0, 2, 3, 1) + quant_t, diff_t, id_t = self.quantize_t(quant_t) + quant_t = quant_t.permute(0, 3, 1, 2) + diff_t = diff_t.unsqueeze(0) + + dec_t = checkpoint(self.dec_t, quant_t) + enc_b = torch.cat([dec_t, enc_b], 1) + + quant_b = self.quantize_conv_b(enc_b, masks).permute(0, 2, 3, 1) + quant_b, diff_b, id_b = self.quantize_b(quant_b) + quant_b = quant_b.permute(0, 3, 1, 2) + diff_b = diff_b.unsqueeze(0) + + return quant_t, quant_b, diff_t + diff_b, id_t, id_b + + def decode(self, quant_t, quant_b, masks): + upsample_t = self.upsample_t(quant_t, masks) + quant = torch.cat([upsample_t, quant_b], 1) + dec = checkpoint(self.dec, quant) + + return dec + + def decode_code(self, code_t, code_b): + quant_t = self.quantize_t.embed_code(code_t) + quant_t = quant_t.permute(0, 3, 1, 2) + quant_b = self.quantize_b.embed_code(code_b) + quant_b = quant_b.permute(0, 3, 1, 2) + + dec = self.decode(quant_t, quant_b, masks) + + return dec + + +@register_model +def register_weighted_vqvae(opt_net, opt): + kw = opt_get(opt_net, ['kwargs'], {}) + return VQVAE(**kw) diff --git a/codes/train.py b/codes/train.py index e57e4152..9117131f 100644 --- a/codes/train.py +++ b/codes/train.py @@ -295,7 +295,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_stylesr.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_imgset_vqvae_stage1/train_imgset_vqvae_stage1_5.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/utils/kmeans.py b/codes/utils/kmeans.py index f5692bb2..0d21f3a4 100644 --- a/codes/utils/kmeans.py +++ b/codes/utils/kmeans.py @@ -115,8 +115,7 @@ def kmeans( def kmeans_predict( X, cluster_centers, - distance='euclidean', - device=torch.device('cpu') + distance='euclidean' ): """ predict using cluster centers @@ -126,8 +125,6 @@ def kmeans_predict( :param device: (torch.device) device [default: 'cpu'] :return: (torch.tensor) cluster ids """ - print(f'predicting on {device}..') - if distance == 'euclidean': pairwise_distance_function = pairwise_distance elif distance == 'cosine': @@ -135,22 +132,13 @@ def kmeans_predict( else: raise NotImplementedError - # convert to float - X = X.float() - - # transfer to device - X = X.to(device) - dis = pairwise_distance_function(X, cluster_centers) choice_cluster = torch.argmin(dis, dim=1) - return choice_cluster.cpu() + return choice_cluster -def pairwise_distance(data1, data2, device=torch.device('cpu')): - # transfer to device - data1, data2 = data1.to(device), data2.to(device) - +def pairwise_distance(data1, data2): # N*1*M A = data1.unsqueeze(dim=1) @@ -163,10 +151,7 @@ def pairwise_distance(data1, data2, device=torch.device('cpu')): return dis -def pairwise_cosine(data1, data2, device=torch.device('cpu')): - # transfer to device - data1, data2 = data1.to(device), data2.to(device) - +def pairwise_cosine(data1, data2): # N*1*M A = data1.unsqueeze(dim=1)
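Two closing notes on the new VQVAE. First, decode_code() calls self.decode(quant_t, quant_b, masks) but never receives masks, so as written it would raise a NameError if used; presumably it needs a masks parameter threaded through (def decode_code(self, code_t, code_b, masks)). Relatedly, kmeans_predict's docstring still documents the removed device parameter. Second, a minimal sketch of the mask contract the new modules share: every ScaledWeightConv/ScaledWeightConvTranspose looks up masks[output_width], a (b, 1, h, w) long tensor of ids in [0, breadth) (assuming the k-means centroids were fit with breadth clusters), which is the dict UResnetMaskProducer emits keyed by scaled width. The snippet fabricates random masks instead of running the k-means producer and assumes the checkpoint() helper behaves like torch.utils.checkpoint; sizes are illustrative.

import torch
from models.vqvae.weighted_conv_vqvae import VQVAE

img = torch.rand(1, 3, 128, 128)
# Widths the 128px pipeline touches: 16 (top code), 32, 64 and 128 (decoder output).
masks = {s: torch.randint(0, 8, (1, 1, s, s)) for s in (16, 32, 64, 128)}
vae = VQVAE(breadth=8)
recon, latent_loss = vae(img, masks)    # recon: (1, 3, 128, 128), latent_loss: quantization loss term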