forked from mrq/ai-voice-cloning
removed the logic to toggle BNB capabilities, since I guess I can't do that from outside the module
This commit is contained in:
parent
225dee22d4
commit
941a27d2b3
|
@ -1,5 +1,5 @@
|
|||
git submodule init
|
||||
git submodule update
|
||||
git submodule update --remote
|
||||
|
||||
python -m venv venv
|
||||
call .\venv\Scripts\activate.bat
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
git submodule init
|
||||
git submodule update
|
||||
git submodule update --remote
|
||||
|
||||
python3 -m venv venv
|
||||
source ./venv/bin/activate
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
git submodule init
|
||||
git submodule update
|
||||
git submodule update --remote
|
||||
|
||||
python -m venv venv
|
||||
call .\venv\Scripts\activate.bat
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
git submodule init
|
||||
git submodule update
|
||||
git submodule update --remote
|
||||
|
||||
python3 -m venv venv
|
||||
source ./venv/bin/activate
|
||||
|
|
19
src/train.py
19
src/train.py
|
@ -2,8 +2,6 @@ import os
|
|||
import sys
|
||||
import argparse
|
||||
|
||||
|
||||
|
||||
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
||||
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
|
||||
|
||||
|
@ -19,16 +17,6 @@ sys.path.insert(0, './dlas/')
|
|||
# don't even really bother trying to get DLAS PIP'd
|
||||
# without kludge, it'll have to be accessible as `codes` and not `dlas`
|
||||
|
||||
import torch_intermediary
|
||||
# could just move this auto-toggle into the MITM script
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
torch_intermediary.OVERRIDE_ADAM = True
|
||||
torch_intermediary.OVERRIDE_ADAMW = True
|
||||
except Exception as e:
|
||||
torch_intermediary.OVERRIDE_ADAM = False
|
||||
torch_intermediary.OVERRIDE_ADAMW = False
|
||||
|
||||
import torch
|
||||
from codes import train as tr
|
||||
from utils import util, options as option
|
||||
|
@ -64,6 +52,13 @@ def train(yaml, launcher='none'):
|
|||
trainer.do_training()
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
import torch_intermediary
|
||||
if torch_intermediary.OVERRIDE_ADAM:
|
||||
print("Using BitsAndBytes ADAMW optimizations")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
|
|
Loading…
Reference in New Issue
Block a user