Add find_faulty_files.py

James Betker 2021-08-25 18:00:43 -06:00
parent 08b33c8e3a
commit 909754cc27
4 changed files with 93 additions and 6 deletions

View File

@@ -211,7 +211,7 @@ class DiscreteVAE(nn.Module):
         out = self.decode(codes)

         # reconstruction loss
-        recon_loss = self.loss_fn(img, out)
+        recon_loss = self.loss_fn(img, out, reduction='none')

         # This is so we can debug the distribution of codes being learned.
         if self.record_codes and self.internal_step % 50 == 0:
@@ -236,7 +236,8 @@ if __name__ == '__main__':
     #v = DiscreteVAE()
     #o=v(torch.randn(1,3,256,256))
     #print(o.shape)
-    v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096, hidden_dim=256, stride=4, num_resnet_blocks=1, kernel_size=5, num_layers=5, use_transposed_convs=False)
+    v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
+                    hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, use_transposed_convs=False)
     #v.eval()
-    o=v(torch.randn(1,1,4096))
+    o=v(torch.randn(1,80,256))
     print(o[-1].shape)
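Why the `reduction='none'` change matters: the reconstruction loss now comes back element-wise instead of pre-averaged, so a caller can reduce it per sample and spot outliers, which is what the new find_faulty_files.py below relies on. A minimal sketch of the idea, assuming an MSE-style loss from `torch.nn.functional` rather than whatever `self.loss_fn` is actually configured to be:

import torch
import torch.nn.functional as F

# Hypothetical batch of 4 spectrogram-like samples (80 channels x 256 frames, matching the new test shape).
img = torch.randn(4, 80, 256)
out = img + 0.01 * torch.randn_like(img)
out[2] += 5.0  # corrupt one sample so it stands out

# reduction='none' keeps the element-wise loss instead of collapsing it to a scalar...
recon_loss = F.mse_loss(out, img, reduction='none')   # shape (4, 80, 256)
# ...so it can be reduced per batch element to flag divergent inputs.
per_sample = recon_loss.view(recon_loss.shape[0], -1).mean(dim=1)
print(per_sample)                       # sample 2 is an obvious outlier
print(torch.nonzero(per_sample > .5))   # the same test find_faulty_files.py applies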

View File

@@ -0,0 +1,86 @@
import os.path as osp
import logging
import random
import time
import argparse
from collections import OrderedDict

import utils
import utils.options as option
import utils.util as util
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np

# Shared with LossWrapper so the offending file paths can be reported alongside the loss values.
current_batch = None


class LossWrapper:
    """Proxies an existing loss, inspecting its per-sample values before delegating to it."""
    def __init__(self, lwrap):
        self.lwrap = lwrap
        self.opt = lwrap.opt

    def is_stateful(self):
        return self.lwrap.is_stateful()

    def extra_metrics(self):
        return self.lwrap.extra_metrics()

    def clear_metrics(self):
        self.lwrap.clear_metrics()

    def __call__(self, m, state):
        global current_batch
        val = state[self.lwrap.key]
        assert val.shape[0] == len(current_batch['path'])
        # Reduce the un-reduced loss to one value per batch element.
        val = val.view(val.shape[0], -1)
        val = val.mean(dim=1)
        errant = torch.nonzero(val > .5)
        for i in errant:
            print(f"ERRANT FOUND: {val[i]} path: {current_batch['path'][i]}")
        return self.lwrap(m, state)


# Script that builds an ExtensibleTrainer, then wraps a pertinent loss with the above LossWrapper.
# The LossWrapper then flags any input that produces a divergent loss.
if __name__ == "__main__":
    # Set seeds
    torch.manual_seed(5555)
    random.seed(5555)
    np.random.seed(5555)

    #### options
    torch.backends.cudnn.benchmark = True
    want_metrics = False
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_lrdvae_audio_clips.yml')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(opt)
    utils.util.loaded_options = opt

    util.mkdirs(
        (path for key, path in opt['path'].items()
         if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
    util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    #### Create test dataset and dataloader
    dataset = create_dataset(opt['datasets']['train'])
    dataloader = create_dataloader(dataset, opt['datasets']['train'])
    logger.info('Number of test images in [{:s}]: {:d}'.format(opt['datasets']['train']['name'], len(dataset)))

    model = ExtensibleTrainer(opt)
    assert len(model.steps) == 1
    step = model.steps[0]
    step.losses['reconstruction_loss'] = LossWrapper(step.losses['reconstruction_loss'])

    for i, data in enumerate(tqdm(dataloader)):
        current_batch = data
        model.feed_data(data, i)
        model.optimize_parameters(i)

View File

@@ -38,7 +38,7 @@ def find_registered_model_fns(base_path='models'):
     module_iter = pkgutil.walk_packages([base_path])
     for mod in module_iter:
         if os.name == 'nt':
-            if os.getcwd() not in mod.module_finder.path:
+            if os.path.join(os.getcwd(), base_path) not in mod.module_finder.path:
                 continue  # I have no idea why this is necessary - I think it's a bug in the latest PyWindows release.
         if mod.ispkg:
             EXCLUSION_LIST = ['flownet2']
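For context on the Windows guard: the old check accepted any finder rooted anywhere under the working directory, while the new one only accepts finders rooted under `<cwd>/<base_path>`, so packages elsewhere in the repo no longer slip through. A rough, hedged sketch of what the check sees (the layout is illustrative, not this repo's actual one):

import os
import pkgutil

base_path = 'models'
wanted = os.path.join(os.getcwd(), base_path)

for mod in pkgutil.walk_packages([base_path]):
    # mod is a ModuleInfo(module_finder, name, ispkg); module_finder.path is the directory it searched.
    if os.name == 'nt' and wanted not in mod.module_finder.path:
        continue  # skip finders rooted outside <cwd>/models, per the fix above
    print(mod.name, mod.ispkg, mod.module_finder.path)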

View File

@@ -5,7 +5,7 @@ from torch.optim import Optimizer
 class SGDNoBiasMomentum(Optimizer):
     r"""
     Copy of pytorch implementation of SGD with a modification which turns off momentum for params marked
-    with `is_bn` or `is_bias`.
+    with `is_norm` or `is_bias`.
     """

     def __init__(self, params, lr, momentum=0, dampening=0,
@@ -54,7 +54,7 @@ class SGDNoBiasMomentum(Optimizer):
                 if weight_decay != 0:
                     d_p = d_p.add(p, alpha=weight_decay)
                 # **this is the only modification over standard torch.optim.SGD:
-                is_bn_or_bias = (hasattr(p, 'is_norm') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
+                is_bn_or_bias = (hasattr(p, 'is_norm') and p.is_norm) or (hasattr(p, 'is_bias') and p.is_bias)
                 if not is_bn_or_bias and momentum != 0:
                     param_state = self.state[p]
                     if 'momentum_buffer' not in param_state:
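The `p.is_bn` → `p.is_norm` fix only matters for parameters that actually carry these attributes; with the typo, tagged norm weights never took the momentum-free path. A hedged sketch of how parameters might be tagged so the check fires (this helper is hypothetical, not part of the repo):

import torch.nn as nn

def tag_params_for_sgd(model: nn.Module) -> nn.Module:
    # Hypothetical helper: mark norm-layer weights and all biases so that
    # SGDNoBiasMomentum's `(hasattr(p, 'is_norm') and p.is_norm)` check skips momentum for them.
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm, nn.LayerNorm)
    for module in model.modules():
        if isinstance(module, norm_types):
            for p in module.parameters(recurse=False):
                p.is_norm = True
    for name, p in model.named_parameters():
        if name.endswith('bias'):
            p.is_bias = True
    return model

# e.g. optimizer = SGDNoBiasMomentum(tag_params_for_sgd(model).parameters(), lr=0.1, momentum=0.9)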