force tf32 to be enabled (torch 1.12 disables it)

This commit is contained in:
James Betker 2022-07-16 13:59:07 -06:00
parent 438dcaccc5
commit a073fbfcb8

View File

@ -94,6 +94,7 @@ class Trainer:
util.set_random_seed(seed)
torch.backends.cudnn.benchmark = opt_get(opt, ['cuda_benchmarking_enabled'], True)
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.deterministic = True
if opt_get(opt, ['anomaly_detection'], False):
torch.autograd.set_detect_anomaly(True)