Hard Routing mods
- Turns out my custom convolution was RIDDLED with backward-pass bugs, which is why the existing implementation wasn't working so well.
- Implements the switch logic from both Mixture of Experts and Switch Transformers for testing purposes.
parent 29c1c3bede
commit 0dca36946f
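For reference, a minimal usage sketch of the hard-routed convolution this commit modifies. The constructor call mirrors how the VQVAE encoder/decoder below instantiate it; the tensor shapes and the direct gate.reset() call are illustrative assumptions, not part of the commit.

    import torch
    from models.switched_conv_hard_routing import SwitchedConvHardRouting

    # 64 -> 64 channels, 3x3 kernels, a breadth of 8 switched kernels; the built-in
    # 'standard' coupler derives the routing selector from the input itself.
    conv = SwitchedConvHardRouting(64, 64, 3, 8, include_coupler=True, coupler_mode='standard',
                                   coupler_dim_in=64, dropout_rate=0.4)
    x = torch.randn(2, 64, 32, 32)
    y = conv(x)                     # each spatial position is routed to exactly one of the 8 kernels
    soft, hard = conv.gate.reset()  # routing statistics normally consumed by SwitchTransformersLoadBalancingLoss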
@@ -9,35 +9,81 @@ import torch.nn.functional as F
 from tqdm import tqdm
 import torch.distributed as dist
 
+from trainer.losses import ConfigurableLoss
+
+
+def SwitchedConvRoutingNormal(input, selector, weight, bias, stride=1):
+    convs = []
+    b, s, h, w = selector.shape
+    for sel in range(s):
+        convs.append(F.conv2d(input, weight[:, :, sel, :, :], bias, stride=stride, padding=weight.shape[-1] // 2))
+    output = torch.stack(convs, dim=1) * selector.unsqueeze(dim=2)
+    return output.sum(dim=1)
+
+
 class SwitchedConvHardRoutingFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, selector, weight, bias, stride=1):
         # Build hard attention mask from selector input
         b, s, h, w = selector.shape
-        selector_mask = (selector.max(dim=1, keepdim=True)[0].repeat(1,s,1,1) == selector).float()
-        mask = selector_mask.argmax(dim=1).int()
-
-        # Compute the convolution using the mask.
-        outputs = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
+        mask = selector.argmax(dim=1).int()
+        output = switched_conv_cuda_naive.forward(input, mask, weight, bias, stride)
 
         ctx.stride = stride
         ctx.breadth = s
         ctx.save_for_backward(*[input, mask, weight, bias])
-        return outputs
+        return output
 
     @staticmethod
     def backward(ctx, grad):
         input, mask, weight, bias = ctx.saved_tensors
-        # Get the grads for the convolution.
-        grad, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
-
-        # Get the selector grads
-        selector_mask = torch.eye(ctx.breadth, device=input.device)[mask.long()].permute(0,3,1,2).unsqueeze(2) # Note that this is not necessarily equivalent to the selector_mask from above, because under certain circumstances, two values could take on the value '1' in the above instance, whereas this is a true one-hot representation.
-        grad_sel = ((grad * input).unsqueeze(1) * selector_mask).sum(2)
+        grad, grad_sel, grad_w, grad_b = switched_conv_cuda_naive.backward(input, grad.contiguous(), mask, weight, bias, ctx.stride)
         return grad, grad_sel, grad_w, grad_b, None
 
 
+# Implements KeepTopK where k=1 from mixture of experts paper.
+class KeepTop1(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
+        input[mask != 1] = -float('inf')
+        ctx.save_for_backward(mask)
+        return input
+
+    @staticmethod
+    def backward(ctx, grad):
+        import pydevd
+        pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        mask = ctx.saved_tensors
+        grad_input = grad.clone()
+        grad_input[mask != 1] = 0
+        return grad_input
+
+
+class RouteTop1(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        mask = torch.nn.functional.one_hot(input.argmax(dim=1), num_classes=input.shape[1]).permute(0,3,1,2)
+        out = torch.ones_like(input)
+        out[mask != 1] = 0
+        ctx.save_for_backward(mask, input.clone())
+        return out
+
+    @staticmethod
+    def backward(ctx, grad):
+        # Enable breakpoints in this function: (Comment out if not debugging)
+        #import pydevd
+        #pydevd.settrace(suspend=False, trace_only_current_thread=True)
+        mask, input = ctx.saved_tensors
+        input[mask != 1] = 1
+        grad_input = grad.clone()
+        grad_input[mask != 1] = 0
+        grad_input_n = grad_input / input # Above, we made everything either a zero or a one. Unscale the ones by dividing by the unmasked inputs.
+        return grad_input_n
+
+
 """
 SwitchNorm is meant to be applied against the Softmax output of an switching function across a large set of
 switch computations. It is meant to promote an equal distribution of switch weights by decreasing the magnitude
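Reading aid for the RouteTop1 function added above (not part of the commit): it emits a hard one-hot selection in the forward pass but still passes a gradient to the winning logit in the backward pass, rescaled by the reciprocal of the selected input value.

    import torch

    p = torch.tensor([0.1, 0.7, 0.2]).view(1, 3, 1, 1).requires_grad_()  # e.g. a softmax over 3 experts
    out = RouteTop1.apply(p)       # one-hot output: only the winning expert is 1
    out.sum().backward()
    print(p.grad.flatten())        # tensor([0.0000, 1.4286, 0.0000]) -> 1 / 0.7 at the selected expert only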
@@ -109,6 +155,107 @@ class SwitchNorm(nn.Module):
         return x / x.sum(dim=1, keepdim=True)
 
 
+class MixtureOfExperts2dRouter(nn.Module):
+    def __init__(self, num_experts):
+        super().__init__()
+        self.num_experts = num_experts
+        self.wnoise = nn.Parameter(torch.zeros(1,num_experts,1,1))
+        self.wg = nn.Parameter(torch.zeros(1,num_experts,1,1))
+
+    def forward(self, x):
+        wg = x * self.wg
+        wnoise = nn.functional.softplus(x * self.wnoise)
+        H = wg + torch.randn_like(x) * wnoise
+
+        # Produce the load-balancing loss.
+        eye = torch.eye(self.num_experts, device=x.device).view(1,self.num_experts,self.num_experts,1,1)
+        mask=torch.abs(1-eye)
+        b,c,h,w=H.shape
+        ninf = torch.zeros_like(eye)
+        ninf[eye==1] = -float('inf')
+        H_masked=H.view(b,c,1,h,w)*mask+ninf # ninf is necessary because otherwise torch.max() will not pick up negative numbered maxes.
+        max_excluding=torch.max(H_masked,dim=2)[0]
+
+        # load_loss and G are stored as local members to facilitate their use by hard routing regularization losses.
+        # this is a risky op - it can easily result in memory leakage. Clients *must* use self.reset() below.
+        self.load_loss = torch.erf((wg - max_excluding)/wnoise)
+        #self.G = nn.functional.softmax(KeepTop1.apply(H), dim=1) The paper proposes this equation, but performing a softmax on a Top-1 per the paper results in zero gradients into H, so:
+        self.G = RouteTop1.apply(nn.functional.softmax(H, dim=1)) # This variant can route gradients downstream.
+
+        return self.G
+
+    # Retrieve the locally stored loss values and delete them from membership (so as to not waste memory)
+    def reset(self):
+        G, load = self.G, self.load_loss
+        del self.G
+        del self.load_loss
+        return G, load
+
+
+# Loss that finds instances of MixtureOfExperts2dRouter in the given network and extracts their custom losses.
+class MixtureOfExpertsLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.routers = [] # This is filled in during the first forward() pass and cached from there.
+        self.first_forward_encountered = False
+        self.load_weight = opt['load_weight']
+        self.importance_weight = opt['importance_weight']
+
+    def forward(self, net, state):
+        if not self.first_forward_encountered:
+            for m in net.modules():
+                if isinstance(m, MixtureOfExperts2dRouter):
+                    self.routers.append(m)
+            self.first_forward_encountered = True
+
+        l_importance = 0
+        l_load = 0
+        for r in self.routers:
+            G, L = r.reset()
+            l_importance += G.var().square()
+            l_load += L.var().square()
+        return l_importance * self.importance_weight + l_load * self.load_weight
+
+
+class SwitchTransformersLoadBalancer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.norm = SwitchNorm(8, accumulator_size=256)
+
+    def forward(self, x):
+        self.soft = self.norm(nn.functional.softmax(x, dim=1))
+        self.hard = RouteTop1.apply(self.soft) # This variant can route gradients downstream.
+        return self.hard
+
+    def reset(self):
+        soft, hard = self.soft, self.hard
+        del self.soft, self.hard
+        return soft, hard
+
+
+class SwitchTransformersLoadBalancingLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.routers = [] # This is filled in during the first forward() pass and cached from there.
+        self.first_forward_encountered = False
+
+    def forward(self, net, state):
+        if not self.first_forward_encountered:
+            for m in net.modules():
+                if isinstance(m, SwitchTransformersLoadBalancer):
+                    self.routers.append(m)
+            self.first_forward_encountered = True
+
+        loss = 0
+        for r in self.routers:
+            soft, hard = r.reset()
+            N = hard.shape[1]
+            h_mean = hard.mean(dim=[0,2,3])
+            s_mean = soft.mean(dim=[0,2,3])
+            loss += torch.dot(h_mean, s_mean) * N
+        return loss
+
+
 class SwitchedConvHardRouting(nn.Module):
     def __init__(self,
                  in_c,
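A note on the balancing term defined just above (a reading aid, not part of the commit): with h_mean the fraction of positions hard-routed to each expert and s_mean the mean soft probability per expert, torch.dot(h_mean, s_mean) * N is the auxiliary load-balancing loss from the Switch Transformers paper. It is smallest when routing is spread evenly over the N experts and grows as routing collapses onto a few of them:

    import torch

    N = 8
    uniform = torch.full((N,), 1.0 / N)             # perfectly balanced routing
    collapsed = torch.zeros(N); collapsed[0] = 1.0  # everything routed to expert 0
    print(torch.dot(uniform, uniform) * N)          # tensor(1.)
    print(torch.dot(collapsed, collapsed) * N)      # tensor(8.)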
@@ -120,8 +267,7 @@ class SwitchedConvHardRouting(nn.Module):
                  dropout_rate=0.0,
                  include_coupler: bool = False, # A 'coupler' is a latent converter which can make any bxcxhxw tensor a compatible switchedconv selector by performing a linear 1x1 conv, softmax and interpolate.
                  coupler_mode: str = 'standard',
-                 coupler_dim_in: int = 0,
-                 switch_norm: bool = True):
+                 coupler_dim_in: int = 0):
         super().__init__()
         self.in_channels = in_c
         self.out_channels = out_c
@@ -130,14 +276,10 @@ class SwitchedConvHardRouting(nn.Module):
         self.has_bias = bias
         self.breadth = breadth
         self.dropout_rate = dropout_rate
-        if switch_norm:
-            self.switch_norm = SwitchNorm(breadth, accumulator_size=512)
-        else:
-            self.switch_norm = None
 
         if include_coupler:
             if coupler_mode == 'standard':
-                self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1)
+                self.coupler = Conv2d(coupler_dim_in, breadth, kernel_size=1, stride=self.stride)
             elif coupler_mode == 'lambda':
                 self.coupler = nn.Sequential(nn.Conv2d(coupler_dim_in, coupler_dim_in, 1),
                                              nn.BatchNorm2d(coupler_dim_in),
@@ -145,9 +287,11 @@ class SwitchedConvHardRouting(nn.Module):
                                              LambdaLayer(dim=coupler_dim_in, dim_out=breadth, r=23, dim_k=16, heads=2, dim_u=1),
                                              nn.BatchNorm2d(breadth),
                                              nn.ReLU(),
-                                             Conv2d(breadth, breadth, 1))
+                                             Conv2d(breadth, breadth, 1, stride=self.stride))
         else:
             self.coupler = None
+        #self.gate = MixtureOfExperts2dRouter(breadth)
+        self.gate = SwitchTransformersLoadBalancer()
 
         self.weight = nn.Parameter(torch.empty(out_c, in_c, breadth, kernel_sz, kernel_sz))
         if bias:
@@ -175,14 +319,10 @@ class SwitchedConvHardRouting(nn.Module):
         # If a coupler was specified, run that to convert selector into a softmax distribution.
         if self.coupler:
             if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed.
-                selector = input.detach()
-            selector = F.softmax(self.coupler(selector), dim=1)
+                selector = input
+            selector = self.coupler(selector)
         assert selector is not None
 
-        # Perform normalization on the selector if applicable.
-        if self.switch_norm:
-            selector = self.switch_norm(selector)
-
         # Apply dropout at the batch level per kernel.
         if self.training and self.dropout_rate > 0:
             b, c, h, w = selector.shape
@@ -192,11 +332,18 @@ class SwitchedConvHardRouting(nn.Module):
             drop = drop.logical_or(fix_blank)
             selector = drop * selector
 
+        selector = self.gate(selector)
+
         # Debugging variables
         self.last_select = selector.detach().clone()
         self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,self.breadth,1,1) == selector).float().argmax(dim=1)
 
-        return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
+        if False:
+            # This is a custom CUDA implementation which should be faster and less memory intensive (once completed).
+            return SwitchedConvHardRoutingFunction.apply(input, selector, self.weight, self.bias, self.stride)
+        else:
+            # This composes the switching functionality using raw Torch, which basically consists of computing each of <breadth> convs separately and combining them.
+            return SwitchedConvRoutingNormal(input, selector, self.weight, self.bias, self.stride)
 
 
 # Given a state_dict and the module that that sd belongs to, strips out all Conv2d.weight parameters and replaces them
@@ -213,7 +360,11 @@ def convert_conv_net_state_dict_to_switched_conv(module, switch_breadth, ignore_
                 continue
         if ignored:
             continue
-        state_dict[f'{name}.weight'] = state_dict[f'{name}.weight'].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
+        if name == '':
+            key = 'weight'
+        else:
+            key = f'{name}.weight'
+        state_dict[key] = state_dict[key].unsqueeze(2).repeat(1,1,switch_breadth,1,1)
     return state_dict
@@ -17,7 +17,7 @@ from utils.util import checkpoint, opt_get
 class UpsampleConv(nn.Module):
     def __init__(self, in_filters, out_filters, breadth, kernel_size, padding):
         super().__init__()
-        self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_filters, dropout_rate=0.4)
+        self.conv = SwitchedConvHardRouting(in_filters, out_filters, kernel_size, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_filters, dropout_rate=0.4)
 
     def forward(self, x):
         up = torch.nn.functional.interpolate(x, scale_factor=2)
@@ -104,16 +104,16 @@ class Encoder(nn.Module):
             blocks = [
                 nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
-                SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4),
+                SwitchedConvHardRouting(channel // 2, channel, 5, breadth, stride=2, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
                 nn.ReLU(inplace=True),
-                SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel, dropout_rate=0.4),
+                SwitchedConvHardRouting(channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel, dropout_rate=0.4),
             ]
 
         elif stride == 2:
             blocks = [
                 nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
-                SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=channel // 2, dropout_rate=0.4),
+                SwitchedConvHardRouting(channel // 2, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=channel // 2, dropout_rate=0.4),
             ]
 
         for i in range(n_res_block):
@@ -133,7 +133,7 @@ class Decoder(nn.Module):
     ):
         super().__init__()
 
-        blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='lambda', coupler_dim_in=in_channel, dropout_rate=0.4)]
+        blocks = [SwitchedConvHardRouting(in_channel, channel, 3, breadth, include_coupler=True, coupler_mode='standard', coupler_dim_in=in_channel, dropout_rate=0.4)]
 
         for i in range(n_res_block):
             blocks.append(ResBlock(channel, n_res_channel, breadth))
@@ -1,78 +1,49 @@
-import os.path as osp
-import logging
-import time
-import argparse
-
 import os
+import shutil
 
-import utils
-from trainer.ExtensibleTrainer import ExtensibleTrainer
-from trainer.networks import define_F
-from utils import options as option
-import utils.util as util
-from data import create_dataset, create_dataloader
+from torch.utils.data import DataLoader
+from data.single_image_dataset import SingleImageDataset
 from tqdm import tqdm
 import torch
 
+from models.vqvae.vqvae_no_conv_transpose import VQVAE
+
 if __name__ == "__main__":
     bin_path = "f:\\binned"
     good_path = "f:\\good"
     os.makedirs(bin_path, exist_ok=True)
     os.makedirs(good_path, exist_ok=True)
 
     torch.backends.cudnn.benchmark = True
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/generator_filter.yml')
-    opt = option.parse(parser.parse_args().opt, is_train=False)
-    opt = option.dict_to_nonedict(opt)
-    opt['dist'] = False
-
-    util.mkdirs(
-        (path for key, path in opt['path'].items()
-         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
-    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
-                      screen=True, tofile=True)
-    logger = logging.getLogger('base')
-    logger.info(option.dict2str(opt))
-
-    #### Create test dataset and dataloader
-    test_loaders = []
-    for phase, dataset_opt in sorted(opt['datasets'].items()):
-        test_set = create_dataset(dataset_opt)
-        test_loader = create_dataloader(test_set, dataset_opt, opt=opt)
-        logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
-        test_loaders.append(test_loader)
-
-    model = ExtensibleTrainer(opt)
-    utils.util.loaded_options = opt
-    fea_loss = 0
-    for test_loader in test_loaders:
-        test_set_name = test_loader.dataset.opt['name']
-        logger.info('\nTesting [{:s}]...'.format(test_set_name))
-        test_start_time = time.time()
-        dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
-        util.mkdir(dataset_dir)
-        netF = define_F(which_model='vgg').to(model.env['device'])
-
-        tq = tqdm(test_loader)
-        removed = 0
-        means = []
-        for data in tq:
-            model.feed_data(data, need_GT=True)
-            model.test()
-            gen = model.eval_state['gen'][0].to(model.env['device'])
-            feagen = netF(gen)
-            feareal = netF(data['hq'].to(model.env['device']))
-            losses = torch.sum(torch.abs(feareal - feagen), dim=(1,2,3))
-            means.append(torch.mean(losses).item())
-            #print(sum(means)/len(means), torch.mean(losses), torch.max(losses), torch.min(losses))
-            for i in range(losses.shape[0]):
-                if losses[i] < 25000:
-                    os.remove(data['GT_path'][i])
-                    removed += 1
-                #imname = osp.basename(data['GT_path'][i])
-                #if losses[i] < 25000:
-                #    torchvision.utils.save_image(data['hq'][i], osp.join(bin_path, imname))
-
-        print("Removed %i/%i images" % (removed, len(test_set)))
+    model = VQVAE().cuda()
+    model.load_state_dict(torch.load('../experiments/nvqvae_imgset.pth'))
+    ds = SingleImageDataset({
+        'name': 'amalgam',
+        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_v5'],
+        'weights': [1],
+        'target_size': 128,
+        'force_multiple': 32,
+        'scale': 1,
+        'eval': False
+    })
+    dl = DataLoader(ds, batch_size=256, num_workers=1)
+
+    means = []
+    model.eval()
+    with torch.no_grad():
+        for i, data in enumerate(tqdm(dl)):
+            hq = data['hq'].cuda()
+            gen = model(hq)[0]
+            l2 = torch.mean(torch.square(hq - gen), dim=[1,2,3])
+            for b in range(len(l2)):
+                if l2[b] > .0004:
+                    shutil.copy(data['GT_path'][b], good_path)
+                #else:
+                #    shutil.copy(data['GT_path'][b], bin_path)
+
+            #means.append(l2.cpu())
+            #if i % 10 == 0:
+            #    print(torch.stack(means, dim=0).mean())
@@ -47,6 +47,12 @@ def create_loss(opt_loss, env):
         return RecurrentLoss(opt_loss, env)
     elif type == 'for_element':
         return ForElementLoss(opt_loss, env)
+    elif type == 'mixture_of_experts':
+        from models.switched_conv_hard_routing import MixtureOfExpertsLoss
+        return MixtureOfExpertsLoss(opt_loss, env)
+    elif type == 'switch_transformer_balance':
+        from models.switched_conv_hard_routing import SwitchTransformersLoadBalancingLoss
+        return SwitchTransformersLoadBalancingLoss(opt_loss, env)
     else:
         raise NotImplementedError
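A hedged sketch of wiring the new loss types through create_loss. The 'type' dispatch and the weight keys follow what the code above reads; the numeric weights, the env object, and any additional keys the ConfigurableLoss base class may expect are placeholders, not taken from this commit.

    # Hypothetical loss options; 'load_weight' and 'importance_weight' are read by MixtureOfExpertsLoss.__init__.
    moe_loss = create_loss({'type': 'mixture_of_experts', 'load_weight': 0.01, 'importance_weight': 0.01}, env)
    balance_loss = create_loss({'type': 'switch_transformer_balance'}, env)

    # Both losses discover their routers by walking net.modules() on their first forward():
    # aux = moe_loss(net, state) + balance_loss(net, state)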