diff --git a/codes/train.py b/codes/train.py index 42027511..d2424ed8 100644 --- a/codes/train.py +++ b/codes/train.py @@ -94,6 +94,7 @@ class Trainer: util.set_random_seed(seed) torch.backends.cudnn.benchmark = opt_get(opt, ['cuda_benchmarking_enabled'], True) + torch.backends.cuda.matmul.allow_tf32 = True # torch.backends.cudnn.deterministic = True if opt_get(opt, ['anomaly_detection'], False): torch.autograd.set_detect_anomaly(True)