From a073fbfcb8d9c5483b66cf4904a68d782cd6bce1 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 16 Jul 2022 13:59:07 -0600 Subject: [PATCH] force tf32 to be enabled (torch 1.12 disables it) --- codes/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/codes/train.py b/codes/train.py index 42027511..d2424ed8 100644 --- a/codes/train.py +++ b/codes/train.py @@ -94,6 +94,7 @@ class Trainer: util.set_random_seed(seed) torch.backends.cudnn.benchmark = opt_get(opt, ['cuda_benchmarking_enabled'], True) + torch.backends.cuda.matmul.allow_tf32 = True # torch.backends.cudnn.deterministic = True if opt_get(opt, ['anomaly_detection'], False): torch.autograd.set_detect_anomaly(True)