add data-dir flag and set all user data directories based on it
This commit is contained in:
parent
9beb794e0b
commit
5eee2ac398
|
@ -7,7 +7,7 @@ import git
|
||||||
from modules import paths, shared
|
from modules import paths, shared
|
||||||
|
|
||||||
extensions = []
|
extensions = []
|
||||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
extensions_dir = os.path.join(paths.data_path, "extensions")
|
||||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules.shared import script_path
|
from modules.paths import data_path, script_path
|
||||||
from modules import shared, ui_tempdir, script_callbacks
|
from modules import shared, ui_tempdir, script_callbacks
|
||||||
import tempfile
|
import tempfile
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -289,7 +289,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||||
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
||||||
def paste_func(prompt):
|
def paste_func(prompt):
|
||||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
||||||
filename = os.path.join(script_path, "params.txt")
|
filename = os.path.join(data_path, "params.txt")
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
with open(filename, "r", encoding="utf8") as file:
|
with open(filename, "r", encoding="utf8") as file:
|
||||||
prompt = file.read()
|
prompt = file.read()
|
||||||
|
|
|
@ -6,12 +6,11 @@ import facexlib
|
||||||
import gfpgan
|
import gfpgan
|
||||||
|
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
from modules import shared, devices, modelloader
|
from modules import paths, shared, devices, modelloader
|
||||||
from modules.paths import models_path
|
|
||||||
|
|
||||||
model_dir = "GFPGAN"
|
model_dir = "GFPGAN"
|
||||||
user_path = None
|
user_path = None
|
||||||
model_path = os.path.join(models_path, model_dir)
|
model_path = os.path.join(paths.models_path, model_dir)
|
||||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||||
have_gfpgan = False
|
have_gfpgan = False
|
||||||
loaded_gfpgan_model = None
|
loaded_gfpgan_model = None
|
||||||
|
|
|
@ -4,8 +4,10 @@ import os.path
|
||||||
|
|
||||||
import filelock
|
import filelock
|
||||||
|
|
||||||
|
from modules.paths import data_path
|
||||||
|
|
||||||
cache_filename = "cache.json"
|
|
||||||
|
cache_filename = os.path.join(data_path, "cache.json")
|
||||||
cache_data = None
|
cache_data = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import devices, paths, lowvram, modelloader, errors
|
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||||
|
|
||||||
blip_image_eval_size = 384
|
blip_image_eval_size = 384
|
||||||
clip_model_name = 'ViT-L/14'
|
clip_model_name = 'ViT-L/14'
|
||||||
|
|
|
@ -4,7 +4,15 @@ import sys
|
||||||
import modules.safe
|
import modules.safe
|
||||||
|
|
||||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
models_path = os.path.join(script_path, "models")
|
|
||||||
|
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||||
|
cmd_opts_pre = parser.parse_known_args()[0]
|
||||||
|
data_path = cmd_opts_pre.data_dir
|
||||||
|
models_path = os.path.join(data_path, "models")
|
||||||
|
|
||||||
|
# data_path = cmd_opts_pre.data
|
||||||
sys.path.insert(0, script_path)
|
sys.path.insert(0, script_path)
|
||||||
|
|
||||||
# search for directory of stable diffusion in following places
|
# search for directory of stable diffusion in following places
|
||||||
|
|
|
@ -17,6 +17,7 @@ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, gener
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
import modules.paths as paths
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
import modules.images as images
|
import modules.images as images
|
||||||
import modules.styles
|
import modules.styles
|
||||||
|
@ -584,7 +585,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
if not p.disable_extra_networks:
|
if not p.disable_extra_networks:
|
||||||
extra_networks.activate(p, extra_network_data)
|
extra_networks.activate(p, extra_network_data)
|
||||||
|
|
||||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
||||||
processed = Processed(p, [], p.seed, "")
|
processed = Processed(p, [], p.seed, "")
|
||||||
file.write(processed.infotext(p, 0))
|
file.write(processed.infotext(p, 0))
|
||||||
|
|
||||||
|
|
|
@ -12,13 +12,13 @@ import ldm.modules.midas as midas
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||||
|
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
checkpoint_alisases = {}
|
checkpoint_alisases = {}
|
||||||
|
@ -307,7 +307,7 @@ def enable_midas_autodownload():
|
||||||
location automatically.
|
location automatically.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
midas_path = os.path.join(models_path, 'midas')
|
midas_path = os.path.join(paths.models_path, 'midas')
|
||||||
|
|
||||||
# stable-diffusion-stability-ai hard-codes the midas model path to
|
# stable-diffusion-stability-ai hard-codes the midas model path to
|
||||||
# a location that differs from where other scripts using this model look.
|
# a location that differs from where other scripts using this model look.
|
||||||
|
|
|
@ -3,13 +3,12 @@ import safetensors.torch
|
||||||
import os
|
import os
|
||||||
import collections
|
import collections
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from modules import shared, devices, script_callbacks, sd_models
|
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||||
from modules.paths import models_path
|
|
||||||
import glob
|
import glob
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
|
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
|
||||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||||
vae_dict = {}
|
vae_dict = {}
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ import modules.memmon
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
|
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
|
||||||
from modules.paths import models_path, script_path
|
from modules.paths import models_path, script_path, data_path
|
||||||
|
|
||||||
|
|
||||||
demo = None
|
demo = None
|
||||||
|
@ -25,6 +25,7 @@ sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
default_sd_model_file = sd_model_file
|
default_sd_model_file = sd_model_file
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||||
|
@ -35,7 +36,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo
|
||||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||||
|
@ -74,16 +75,16 @@ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for sp
|
||||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
|
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||||
|
|
|
@ -6,8 +6,7 @@ import sys
|
||||||
import tqdm
|
import tqdm
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from modules import shared, images, deepbooru
|
from modules import paths, shared, images, deepbooru
|
||||||
from modules.paths import models_path
|
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
from modules.textual_inversion import autocrop
|
from modules.textual_inversion import autocrop
|
||||||
|
|
||||||
|
@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||||
|
|
||||||
dnn_model_path = None
|
dnn_model_path = None
|
||||||
try:
|
try:
|
||||||
dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
|
dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
|
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad
|
||||||
|
|
||||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
|
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path, data_path
|
||||||
|
|
||||||
from modules.shared import opts, cmd_opts, restricted_opts
|
from modules.shared import opts, cmd_opts, restricted_opts
|
||||||
|
|
||||||
|
@ -1497,8 +1497,8 @@ def create_ui():
|
||||||
with open(cssfile, "r", encoding="utf8") as file:
|
with open(cssfile, "r", encoding="utf8") as file:
|
||||||
css += file.read() + "\n"
|
css += file.read() + "\n"
|
||||||
|
|
||||||
if os.path.exists(os.path.join(script_path, "user.css")):
|
if os.path.exists(os.path.join(data_path, "user.css")):
|
||||||
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
|
with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
|
||||||
css += file.read() + "\n"
|
css += file.read() + "\n"
|
||||||
|
|
||||||
if not cmd_opts.no_progressbar_hiding:
|
if not cmd_opts.no_progressbar_hiding:
|
||||||
|
|
|
@ -132,7 +132,7 @@ def install_extension_from_url(dirname, url):
|
||||||
normalized_url = normalize_git_url(url)
|
normalized_url = normalize_git_url(url)
|
||||||
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
|
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
|
||||||
|
|
||||||
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
|
tmpdir = os.path.join(paths.data_path, "tmp", dirname)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
shutil.rmtree(tmpdir, True)
|
shutil.rmtree(tmpdir, True)
|
||||||
|
|
|
@ -11,7 +11,6 @@ from modules import modelloader, shared
|
||||||
|
|
||||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
|
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
|
||||||
from modules.paths import models_path
|
|
||||||
|
|
||||||
|
|
||||||
class Upscaler:
|
class Upscaler:
|
||||||
|
@ -39,7 +38,7 @@ class Upscaler:
|
||||||
self.mod_scale = None
|
self.mod_scale = None
|
||||||
|
|
||||||
if self.model_path is None and self.name:
|
if self.model_path is None and self.name:
|
||||||
self.model_path = os.path.join(models_path, self.name)
|
self.model_path = os.path.join(shared.models_path, self.name)
|
||||||
if self.model_path and create_dirs:
|
if self.model_path and create_dirs:
|
||||||
os.makedirs(self.model_path, exist_ok=True)
|
os.makedirs(self.model_path, exist_ok=True)
|
||||||
|
|
||||||
|
@ -143,4 +142,4 @@ class UpscalerNearest(Upscaler):
|
||||||
def __init__(self, dirname=None):
|
def __init__(self, dirname=None):
|
||||||
super().__init__(False)
|
super().__init__(False)
|
||||||
self.name = "Nearest"
|
self.name = "Nearest"
|
||||||
self.scalers = [UpscalerData("Nearest", None, self)]
|
self.scalers = [UpscalerData("Nearest", None, self)]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user