diff --git a/setup-cuda.bat b/setup-cuda.bat
index c8346e2..6ad903a 100755
--- a/setup-cuda.bat
+++ b/setup-cuda.bat
@@ -1,5 +1,5 @@
 git submodule init
-git submodule update
+git submodule update --remote
 python -m venv venv
 call .\venv\Scripts\activate.bat
 
diff --git a/setup-cuda.sh b/setup-cuda.sh
index 3da2a1e..5901a1d 100755
--- a/setup-cuda.sh
+++ b/setup-cuda.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 git submodule init
-git submodule update
+git submodule update --remote
 python3 -m venv venv
 source ./venv/bin/activate
 
diff --git a/setup-directml.bat b/setup-directml.bat
index eff0949..2cc193e 100755
--- a/setup-directml.bat
+++ b/setup-directml.bat
@@ -1,5 +1,5 @@
 git submodule init
-git submodule update
+git submodule update --remote
 python -m venv venv
 call .\venv\Scripts\activate.bat
 
diff --git a/setup-rocm.sh b/setup-rocm.sh
index 302d65e..53c0a12 100755
--- a/setup-rocm.sh
+++ b/setup-rocm.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 git submodule init
-git submodule update
+git submodule update --remote
 python3 -m venv venv
 source ./venv/bin/activate
 
diff --git a/src/train.py b/src/train.py
index 900261f..504fef9 100755
--- a/src/train.py
+++ b/src/train.py
@@ -2,8 +2,6 @@
 import os
 import sys
 import argparse
-
-
 
 # this is some massive kludge that only works if it's called from a shell and not an import/PIP package
 # it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
@@ -19,16 +17,6 @@ sys.path.insert(0, './dlas/')
 # don't even really bother trying to get DLAS PIP'd
 # without kludge, it'll have to be accessible as `codes` and not `dlas`
 
-import torch_intermediary
-# could just move this auto-toggle into the MITM script
-try:
-	import bitsandbytes as bnb
-	torch_intermediary.OVERRIDE_ADAM = True
-	torch_intermediary.OVERRIDE_ADAMW = True
-except Exception as e:
-	torch_intermediary.OVERRIDE_ADAM = False
-	torch_intermediary.OVERRIDE_ADAMW = False
-
 import torch
 from codes import train as tr
 from utils import util, options as option
@@ -64,6 +52,13 @@ def train(yaml, launcher='none'):
 	trainer.do_training()
 
 if __name__ == "__main__":
+	try:
+		import torch_intermediary
+		if torch_intermediary.OVERRIDE_ADAM:
+			print("Using BitsAndBytes ADAMW optimizations")
+	except Exception as e:
+		pass
+
 	parser = argparse.ArgumentParser()
 	parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
 	parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')