Update find_faulty_files
This commit is contained in:
parent
9191201f05
commit
a66a2bf91b
|
@ -15,6 +15,7 @@ import torch
|
|||
import numpy as np
|
||||
|
||||
current_batch = None
|
||||
output_file = open('find_faulty_files_results.tsv', 'a')
|
||||
|
||||
class LossWrapper:
|
||||
def __init__(self, lwrap):
|
||||
|
@ -32,13 +33,16 @@ class LossWrapper:
|
|||
|
||||
def __call__(self, m, state):
|
||||
global current_batch
|
||||
global output_file
|
||||
val = state[self.lwrap.key]
|
||||
assert val.shape[0] == len(current_batch['path'])
|
||||
val = val.view(val.shape[0], -1)
|
||||
val = val.mean(dim=1)
|
||||
errant = torch.nonzero(val > .5)
|
||||
errant = torch.nonzero(val > 8)
|
||||
for i in errant:
|
||||
print(f"ERRANT FOUND: {val[i]} path: {current_batch['path'][i]}")
|
||||
output_file.write(current_batch['path'][i] + "\n")
|
||||
output_file.flush()
|
||||
return self.lwrap(m, state)
|
||||
|
||||
|
||||
|
@ -54,7 +58,7 @@ if __name__ == "__main__":
|
|||
torch.backends.cudnn.benchmark = True
|
||||
want_metrics = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_lrdvae_audio_clips.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../experiments/clean_with_lrdvae.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=True)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
utils.util.loaded_options = opt
|
||||
|
|
Loading…
Reference in New Issue
Block a user