From 0dca36946f3f240151ba0587d6fde14c9cf99291 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 2 Feb 2021 20:35:58 -0700
Subject: [PATCH] Hard Routing mods

- Turns out my custom convolution was riddled with bugs in its backward
  pass, which is why the existing implementation wasn't working so well.
- Implements the switch logic from both the Mixture of Experts and Switch
  Transformers papers for testing purposes.
---
 codes/models/switched_conv_hard_routing.py    | 207 +++++++++++++++---
 ...e_no_conv_transpose_hardswitched_lambda.py |  10 +-
 codes/scripts/use_generator_as_filter.py      |  95 +++-----
 codes/trainer/losses.py                       |   6 +
 4 files changed, 223 insertions(+), 95 deletions(-)

diff --git a/codes/models/switched_conv_hard_routing.py b/codes/models/switched_conv_hard_routing.py
index a9320063..2e7ea9a6 100644
--- a/codes/models/switched_conv_hard_routing.py
+++ b/codes/models/switched_conv_hard_routing.py
@@ -9,35 +9,81 @@ import torch.nn.functional as F
 from tqdm import tqdm
 import torch.distributed as dist
 
+from trainer.losses import ConfigurableLoss
+
+
+def SwitchedConvRoutingNormal(input, selector, weight, bias, stride=1):
+    convs = []
+    b, s, h, w = selector.shape
+    for sel in range(s):
+        convs.append(F.conv2d(input, weight[:, :, sel, :, :], bias, stride=stride, padding=weight.shape[-1] // 2))
+    output = torch.stack(convs, dim=1) * selector.unsqueeze(dim=2)
+    return output.sum(dim=1)
+
 
 class SwitchedConvHardRoutingFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, selector, weight, bias, stride=1):
         # Build hard attention mask from selector input
         b, s, h, w = selector.shape
-        selector_mask = (selector.max(dim=1, keepdim=True)[0].repeat(1,s,1,1) == selector).float()
-        mask = selector_mask.argmax(dim=1).int()
-        # Compute the convolution using the mask.
-        outputs = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
+        mask = selector.argmax(dim=1).int()
+        output = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
+
         ctx.stride = stride
         ctx.breadth = s
         ctx.save_for_backward(*[input, mask, weight, bias])
-        return outputs
+        return output
 
     @staticmethod
     def backward(ctx, grad):
         input, mask, weight, bias = ctx.saved_tensors
-
-        # Get the grads for the convolution.
-        grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
-
-        # Get the selector grads
-        selector_mask = torch.eye(ctx.breadth, device=input.device)[mask.long()].permute(0,3,1,2).unsqueeze(2)  # Note that this is not necessarily equivalent to the selector_mask from above, because under certain circumstances, two values could take on the value '1' in the above instance, whereas this is a true one-hot representation.
-        grad_sel = ((grad * input).unsqueeze(1) * selector_mask).sum(2)
+        grad, grad_sel, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
         return grad, grad_sel, grad_w, grad_b, None
 
 
+# Implements KeepTopK with k=1 from the Mixture of Experts paper.
+class KeepTop1(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
+        input[mask != 1] = -float('inf')
+        ctx.save_for_backward(mask)
+        return input
+
+    @staticmethod
+    def backward(ctx, grad):
+        #import pydevd
+        #pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        mask = ctx.saved_tensors[0]
+        grad_input = grad.clone()
+        grad_input[mask != 1] = 0
+        return grad_input
+
+
+class RouteTop1(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
+        out = torch.ones_like(input)
+        out[mask != 1] = 0
+        ctx.save_for_backward(mask, input.clone())
+        return out
+
+    @staticmethod
+    def backward(ctx, grad):
+        # Enable breakpoints in this function: (Comment out if not debugging)
+        #import pydevd
+        #pydevd.settrace(suspend=False, trace_only_current_thread=True)
+
+        mask, input = ctx.saved_tensors
+        input[mask != 1] = 1
+        grad_input = grad.clone()
+        grad_input[mask != 1] = 0
+        grad_input_n = grad_input / input  # Above, we made everything either a zero or a one. Unscale the ones by dividing by the unmasked inputs.
+        return grad_input_n
+
+
 """
 SwitchNorm is meant to be applied against the Softmax output of a switching function across a large set of
 switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
@@ -109,6 +155,107 @@ class SwitchNorm(nn.Module):
         return x / x.sum(dim=1, keepdim=True)
 
 
+class MixtureOfExperts2dRouter(nn.Module):
+    def __init__(self, num_experts):
+        super().__init__()
+        self.num_experts = num_experts
+        self.wnoise = nn.Parameter(torch.zeros(1,num_experts,1,1))
+        self.wg = nn.Parameter(torch.zeros(1,num_experts,1,1))
+
+    def forward(self, x):
+        wg = x * self.wg
+        wnoise = nn.functional.softplus(x * self.wnoise)
+        H = wg + torch.randn_like(x) * wnoise
+
+        # Produce the load-balancing loss.
+        eye = torch.eye(self.num_experts, device=x.device).view(1,self.num_experts,self.num_experts,1,1)
+        mask=torch.abs(1-eye)
+        b,c,h,w=H.shape
+        ninf = torch.zeros_like(eye)
+        ninf[eye==1] = -float('inf')
+        H_masked=H.view(b,c,1,h,w)*mask+ninf  # ninf is necessary because otherwise torch.max() will not pick up negative-valued maxes.
+        max_excluding=torch.max(H_masked,dim=2)[0]
+
+        # load_loss and G are stored as local members to facilitate their use by hard routing regularization losses.
+        # This is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below.
+        self.load_loss = torch.erf((wg - max_excluding)/wnoise)
+        #self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1)  # The paper proposes this equation, but performing a softmax on a top-1 output results in zero gradients into H, so:
+        self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1))  # This variant can route gradients downstream.
+
+        return self.G
+
+    # Retrieve the locally stored loss values and delete them from membership (so as to not waste memory)
+    def reset(self):
+        G, load = self.G, self.load_loss
+        del self.G
+        del self.load_loss
+        return G, load
+
+
+# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses.
+class MixtureOfExpertsLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.routers = []  # This is filled in during the first forward() pass and cached from there.
+        self.first_forward_encountered = False
+        self.load_weight = opt['load_weight']
+        self.importance_weight = opt['importance_weight']
+
+    def forward(self, net, state):
+        if not self.first_forward_encountered:
+            for m in net.modules():
+                if isinstance(m, MixtureOfExperts2dRouter):
+                    self.routers.append(m)
+            self.first_forward_encountered = True
+
+        l_importance = 0
+        l_load = 0
+        for r in self.routers:
+            G, L = r.reset()
+            l_importance += G.var().square()
+            l_load += L.var().square()
+        return l_importance * self.importance_weight + l_load * self.load_weight
+
+
+class SwitchTransformersLoadBalancer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.norm = SwitchNorm(8, accumulator_size=256)
+
+    def forward(self, x):
+        self.soft = self.norm(nn.functional.softmax(x, dim=1))
+        self.hard = RouteTop1.apply(self.soft)  # This variant can route gradients downstream.
+        return self.hard
+
+    def reset(self):
+        soft, hard = self.soft, self.hard
+        del self.soft, self.hard
+        return soft, hard
+
+
+class SwitchTransformersLoadBalancingLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.routers = []  # This is filled in during the first forward() pass and cached from there.
+        self.first_forward_encountered = False
+
+    def forward(self, net, state):
+        if not self.first_forward_encountered:
+            for m in net.modules():
+                if isinstance(m, SwitchTransformersLoadBalancer):
+                    self.routers.append(m)
+            self.first_forward_encountered = True
+
+        loss = 0
+        for r in self.routers:
+            soft, hard = r.reset()
+            N = hard.shape[1]
+            h_mean = hard.mean(dim=[0,2,3])
+            s_mean = soft.mean(dim=[0,2,3])
+            loss += torch.dot(h_mean, s_mean) * N
+        return loss
+
+
 class SwitchedConvHardRouting(nn.Module):
     def __init__(self,
                  in_c,
@@ -120,8 +267,7 @@ class SwitchedConvHardRouting(nn.Module):
                  dropout_rate=0.0,
                  include_coupler: bool = False,  # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
                  coupler_mode: str = 'standard',
-                 coupler_dim_in: int = 0,
-                 switch_norm: bool = True):
+                 coupler_dim_in: int = 0):
         super().__init__()
         self.in_channels = in_c
         self.out_channels = out_c
@@ -130,14 +276,10 @@ class SwitchedConvHardRouting(nn.Module):
         self.has_bias = bias
         self.breadth = breadth
         self.dropout_rate = dropout_rate
-        if switch_norm:
-            self.switch_norm = SwitchNorm(breadth, accumulator_size=512)
-        else:
-            self.switch_norm = None
 
         if include_coupler:
             if coupler_mode == 'standard':
-                self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1)
+                self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
             elif coupler_mode == 'lambda':
                 self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
                                              nn.BatchNorm2d(coupler_dim_in),
@@ -145,9 +287,11 @@ class SwitchedConvHardRouting(nn.Module):
                                              LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
                                              nn.BatchNorm2d(breadth),
                                              nn.ReLU(),
-                                             Conv2d(breadth, breadth, 1))
+                                             Conv2d(breadth, breadth, 1, stride=self.stride))
         else:
             self.coupler = None
+        #self.gate = MixtureOfExperts2dRouter(breadth)
+        self.gate = SwitchTransformersLoadBalancer()
 
         self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
         if bias:
@@ -175,14 +319,10 @@ class SwitchedConvHardRouting(nn.Module):
 
         # If a coupler was specified, run that to convert selector into a softmax distribution.
         if self.coupler:
             if selector is None:  # A coupler can convert from any input to a selector, so 'None' is allowed.
-                selector = input.detach()
-            selector = F.softmax(self.coupler(selector), dim=1)
+                selector = input
+            selector = self.coupler(selector)
         assert selector is not None
-        # Perform normalization on the selector if applicable.
-        if self.switch_norm:
-            selector = self.switch_norm(selector)
-
         # Apply dropout at the batch level per kernel.
         if self.training and self.dropout_rate > 0:
             b, c, h, w = selector.shape
@@ -192,11 +332,18 @@ class SwitchedConvHardRouting(nn.Module):
             drop = drop.logical_or(fix_blank)
             selector = drop * selector
 
+        selector = self.gate(selector)
+
         # Debugging variables
         self.last_select = selector.detach().clone()
         self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)
 
-        return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
+        if False:
+            # This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
+            return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
+        else:
+            # This composes the switching functionality using raw Torch, which basically consists of computing each of the convs separately and combining them.
+            return SwitchedConvRoutingNormal(input, selector, self.weight, self.bias, self.stride)
 
 
 # Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them
@@ -213,7 +360,11 @@ def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_
                 continue
         if ignored:
             continue
-        state_dict[f'{name}.weight'] = state_dict[f'{name}.weight'].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
+        if name == '':
+            key = 'weight'
+        else:
+            key = f'{name}.weight'
+        state_dict[key] = state_dict[key].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
     return state_dict
 
 
diff --git a/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py b/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py
index 06ab7045..9c29648b 100644
--- a/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py
+++ b/codes/models/vqvae/vqvae_no_conv_transpose_hardswitched_lambda.py
@@ -17,7 +17,7 @@ from utils.util import checkpoint, opt_get
 class UpsampleConv(nn.Module):
     def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
         super().__init__()
-        self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters, dropout_rate=0.4)
+        self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_filters, dropout_rate=0.4)
 
     def forward(self, x):
         up = torch.nn.functional.interpolate(x, scale_factor=2)
@@ -104,16 +104,16 @@ class Encoder(nn.Module):
             blocks = [
                 nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
-                SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4),
+                SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
                 nn.ReLU(inplace=True),
-                SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel, dropout_rate=0.4),
+                SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel, dropout_rate=0.4),
             ]
         elif stride == 2:
             blocks = [
                 nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
-                SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4),
+                SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
             ]
 
         for i in range(n_res_block):
@@ -133,7 +133,7 @@ class Decoder(nn.Module):
     ):
         super().__init__()
 
-        blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.4)]
+        blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_channel, dropout_rate=0.4)]
 
         for i in range(n_res_block):
             blocks.append(ResBlock(channel, n_res_channel, breadth))
diff --git a/codes/scripts/use_generator_as_filter.py b/codes/scripts/use_generator_as_filter.py
index 3b5b0146..f08f7d44 100644
--- a/codes/scripts/use_generator_as_filter.py
+++ b/codes/scripts/use_generator_as_filter.py
@@ -1,78 +1,49 @@
-import os.path as osp
-import logging
-import time
-import argparse
-
 import os
+import shutil
 
-import utils
-from trainer.ExtensibleTrainer import ExtensibleTrainer
-from trainer.networks import define_F
-from utils import options as option
-import utils.util as util
-from data import create_dataset, create_dataloader
+from torch.utils.data import DataLoader
+
+from data.single_image_dataset import SingleImageDataset
 from tqdm import tqdm
 import torch
 
+from models.vqvae.vqvae_no_conv_transpose import VQVAE
+
 if __name__ == "__main__":
     bin_path = "f:\\binned"
     good_path = "f:\\good"
     os.makedirs(bin_path, exist_ok=True)
     os.makedirs(good_path, exist_ok=True)
 
-    torch.backends.cudnn.benchmark = True
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/generator_filter.yml')
-    opt = option.parse(parser.parse_args().opt, is_train=False)
-    opt = option.dict_to_nonedict(opt)
-    opt['dist'] = False
-    util.mkdirs(
-        (path for key, path in opt['path'].items()
-         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
-    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
-                      screen=True, tofile=True)
-    logger = logging.getLogger('base')
-    logger.info(option.dict2str(opt))
+    model = VQVAE().cuda()
+    model.load_state_dict(torch.load('../experiments/nvqvae_imgset.pth'))
+    ds = SingleImageDataset({
+        'name': 'amalgam',
+        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5'],
+        'weights': [1],
+        'target_size': 128,
+        'force_multiple': 32,
+        'scale': 1,
+        'eval': False
+    })
+    dl = DataLoader(ds, batch_size=256, num_workers=1)
 
-    #### Create test dataset and dataloader
-    test_loaders = []
-    for phase, dataset_opt in sorted(opt['datasets'].items()):
-        test_set = create_dataset(dataset_opt)
-        test_loader = create_dataloader(test_set, dataset_opt, opt=opt)
-        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
-        test_loaders.append(test_loader)
+    means = []
+    model.eval()
+    with torch.no_grad():
+        for i, data in enumerate(tqdm(dl)):
+            hq = data['hq'].cuda()
+            gen = model(hq)[0]
+            l2 = torch.mean(torch.square(hq - gen), dim=[1,2,3])
+            for b in range(len(l2)):
+                if l2[b] > .0004:
+                    shutil.copy(data['GT_path'][b], good_path)
+                #else:
+                #    shutil.copy(data['GT_path'][b], bin_path)
 
-    model = ExtensibleTrainer(opt)
-    utils.util.loaded_options = opt
-    fea_loss = 0
-    for test_loader in test_loaders:
-        test_set_name = test_loader.dataset.opt['name']
-        logger.info('\nTesting [{:s}]...'.format(test_set_name))
-        test_start_time = time.time()
-        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
-        util.mkdir(dataset_dir)
-        netF = define_F(which_model='vgg').to(model.env['device'])
-        tq = tqdm(test_loader)
-        removed = 0
-        means = []
-        for data in tq:
-            model.feed_data(data, need_GT=True)
-            model.test()
-            gen = model.eval_state['gen'][0].to(model.env['device'])
-            feagen = netF(gen)
-            feareal = netF(data['hq'].to(model.env['device']))
-            losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
-            means.append(torch.mean(losses).item())
-            #print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
-            for i in range(losses.shape[0]):
-                if losses[i] < 25000:
-                    os.remove(data['GT_path'][i])
-                    removed += 1
-                #imname = osp.basename(data['GT_path'][i])
-                #if losses[i] < 25000:
-                #    torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
-
-    print("Removed %i/%i images" % (removed, len(test_set)))
\ No newline at end of file
+            #means.append(l2.cpu())
+            #if i % 10 == 0:
+            #    print(torch.stack(means, dim=0).mean())
diff --git a/codes/trainer/losses.py b/codes/trainer/losses.py
index c2d81147..33aae699 100644
--- a/codes/trainer/losses.py
+++ b/codes/trainer/losses.py
@@ -47,6 +47,12 @@ def create_loss(opt_loss, env):
         return RecurrentLoss(opt_loss, env)
     elif type == 'for_element':
         return ForElementLoss(opt_loss, env)
+    elif type == 'mixture_of_experts':
+        from models.switched_conv_hard_routing import MixtureOfExpertsLoss
+        return MixtureOfExpertsLoss(opt_loss, env)
+    elif type == 'switch_transformer_balance':
+        from models.switched_conv_hard_routing import SwitchTransformersLoadBalancingLoss
+        return SwitchTransformersLoadBalancingLoss(opt_loss, env)
     else:
        raise NotImplementedError
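
Aside: below is a minimal standalone sketch of the Switch Transformers-style balance term that SwitchTransformersLoadBalancingLoss computes from the soft/hard selector pair returned by SwitchTransformersLoadBalancer.reset(). It is an illustration only, not part of the patch; the helper name switch_balance_loss and the example shapes (batch 4, breadth 8, 16x16 spatial) are assumptions made for this sketch.

# Standalone sketch (assumed names/shapes): Switch Transformers load-balancing term
# as used by SwitchTransformersLoadBalancingLoss in the patch above.
import torch
import torch.nn.functional as F


def switch_balance_loss(soft: torch.Tensor, hard: torch.Tensor) -> torch.Tensor:
    # soft: [b, breadth, h, w] softmax routing probabilities.
    # hard: [b, breadth, h, w] one-hot routing decisions (argmax of `soft`).
    n = hard.shape[1]
    f = hard.mean(dim=[0, 2, 3])  # fraction of locations routed to each branch
    p = soft.mean(dim=[0, 2, 3])  # mean routing probability per branch
    # Matches `torch.dot(h_mean, s_mean) * N` in the patch: equals 1 under perfectly
    # uniform routing and grows toward n as routing collapses onto fewer branches.
    return n * torch.dot(f, p)


if __name__ == '__main__':
    logits = torch.randn(4, 8, 16, 16)
    soft = F.softmax(logits, dim=1)
    hard = F.one_hot(soft.argmax(dim=1), num_classes=8).permute(0, 3, 1, 2).float()
    print(switch_balance_loss(soft, hard))

Because the term is minimized by uniform routing and penalizes collapse onto a single branch, it is exposed through create_loss() (types 'mixture_of_experts' and 'switch_transformer_balance') so it can be added as an auxiliary training loss alongside the main objective.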