add data-dir flag and set all user data directories based on it

This commit is contained in:
Max Audron 2023-01-25 17:15:42 +01:00
parent 9beb794e0b
commit 5eee2ac398
14 changed files with 39 additions and 31 deletions

View File

@ -7,7 +7,7 @@ import git
from modules import paths, shared from modules import paths, shared
extensions = [] extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions") extensions_dir = os.path.join(paths.data_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")

View File

@ -6,7 +6,7 @@ import re
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
from modules.shared import script_path from modules.paths import data_path, script_path
from modules import shared, ui_tempdir, script_callbacks from modules import shared, ui_tempdir, script_callbacks
import tempfile import tempfile
from PIL import Image from PIL import Image
@ -289,7 +289,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
def connect_paste(button, paste_fields, input_comp, jsfunc=None): def connect_paste(button, paste_fields, input_comp, jsfunc=None):
def paste_func(prompt): def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config: if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(script_path, "params.txt") filename = os.path.join(data_path, "params.txt")
if os.path.exists(filename): if os.path.exists(filename):
with open(filename, "r", encoding="utf8") as file: with open(filename, "r", encoding="utf8") as file:
prompt = file.read() prompt = file.read()

View File

@ -6,12 +6,11 @@ import facexlib
import gfpgan import gfpgan
import modules.face_restoration import modules.face_restoration
from modules import shared, devices, modelloader from modules import paths, shared, devices, modelloader
from modules.paths import models_path
model_dir = "GFPGAN" model_dir = "GFPGAN"
user_path = None user_path = None
model_path = os.path.join(models_path, model_dir) model_path = os.path.join(paths.models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False have_gfpgan = False
loaded_gfpgan_model = None loaded_gfpgan_model = None

View File

@ -4,8 +4,10 @@ import os.path
import filelock import filelock
from modules.paths import data_path
cache_filename = "cache.json"
cache_filename = os.path.join(data_path, "cache.json")
cache_data = None cache_data = None

View File

@ -12,7 +12,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared import modules.shared as shared
from modules import devices, paths, lowvram, modelloader, errors from modules import devices, paths, shared, lowvram, modelloader, errors
blip_image_eval_size = 384 blip_image_eval_size = 384
clip_model_name = 'ViT-L/14' clip_model_name = 'ViT-L/14'

View File

@ -4,7 +4,15 @@ import sys
import modules.safe import modules.safe
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser = argparse.ArgumentParser()
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
cmd_opts_pre = parser.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models")
# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path) sys.path.insert(0, script_path)
# search for directory of stable diffusion in following places # search for directory of stable diffusion in following places

View File

@ -17,6 +17,7 @@ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, gener
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.paths as paths
import modules.face_restoration import modules.face_restoration
import modules.images as images import modules.images as images
import modules.styles import modules.styles
@ -584,7 +585,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if not p.disable_extra_networks: if not p.disable_extra_networks:
extra_networks.activate(p, extra_network_data) extra_networks.activate(p, extra_network_data)
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "") processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0)) file.write(processed.infotext(p, 0))

View File

@ -12,13 +12,13 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.paths import models_path from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer from modules.timer import Timer
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
checkpoints_list = {} checkpoints_list = {}
checkpoint_alisases = {} checkpoint_alisases = {}
@ -307,7 +307,7 @@ def enable_midas_autodownload():
location automatically. location automatically.
""" """
midas_path = os.path.join(models_path, 'midas') midas_path = os.path.join(paths.models_path, 'midas')
# stable-diffusion-stability-ai hard-codes the midas model path to # stable-diffusion-stability-ai hard-codes the midas model path to
# a location that differs from where other scripts using this model look. # a location that differs from where other scripts using this model look.

View File

@ -3,13 +3,12 @@ import safetensors.torch
import os import os
import collections import collections
from collections import namedtuple from collections import namedtuple
from modules import shared, devices, script_callbacks, sd_models from modules import paths, shared, devices, script_callbacks, sd_models
from modules.paths import models_path
import glob import glob
from copy import deepcopy from copy import deepcopy
vae_path = os.path.abspath(os.path.join(models_path, "VAE")) vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
vae_dict = {} vae_dict = {}

View File

@ -14,7 +14,7 @@ import modules.memmon
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import localization, extensions, script_loading, errors, ui_components, shared_items from modules import localization, extensions, script_loading, errors, ui_components, shared_items
from modules.paths import models_path, script_path from modules.paths import models_path, script_path, data_path
demo = None demo = None
@ -25,6 +25,7 @@ sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
@ -35,7 +36,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
@ -74,16 +75,16 @@ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for sp
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json')) parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)

View File

@ -6,8 +6,7 @@ import sys
import tqdm import tqdm
import time import time
from modules import shared, images, deepbooru from modules import paths, shared, images, deepbooru
from modules.paths import models_path
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop from modules.textual_inversion import autocrop
@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
dnn_model_path = None dnn_model_path = None
try: try:
dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv")) dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
except Exception as e: except Exception as e:
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)

View File

@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path from modules.paths import script_path, data_path
from modules.shared import opts, cmd_opts, restricted_opts from modules.shared import opts, cmd_opts, restricted_opts
@ -1497,8 +1497,8 @@ def create_ui():
with open(cssfile, "r", encoding="utf8") as file: with open(cssfile, "r", encoding="utf8") as file:
css += file.read() + "\n" css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")): if os.path.exists(os.path.join(data_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
css += file.read() + "\n" css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding: if not cmd_opts.no_progressbar_hiding:

View File

@ -132,7 +132,7 @@ def install_extension_from_url(dirname, url):
normalized_url = normalize_git_url(url) normalized_url = normalize_git_url(url)
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
tmpdir = os.path.join(paths.script_path, "tmp", dirname) tmpdir = os.path.join(paths.data_path, "tmp", dirname)
try: try:
shutil.rmtree(tmpdir, True) shutil.rmtree(tmpdir, True)

View File

@ -11,7 +11,6 @@ from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
from modules.paths import models_path
class Upscaler: class Upscaler:
@ -39,7 +38,7 @@ class Upscaler:
self.mod_scale = None self.mod_scale = None
if self.model_path is None and self.name: if self.model_path is None and self.name:
self.model_path = os.path.join(models_path, self.name) self.model_path = os.path.join(shared.models_path, self.name)
if self.model_path and create_dirs: if self.model_path and create_dirs:
os.makedirs(self.model_path, exist_ok=True) os.makedirs(self.model_path, exist_ok=True)
@ -143,4 +142,4 @@ class UpscalerNearest(Upscaler):
def __init__(self, dirname=None): def __init__(self, dirname=None):
super().__init__(False) super().__init__(False)
self.name = "Nearest" self.name = "Nearest"
self.scalers = [UpscalerData("Nearest", None, self)] self.scalers = [UpscalerData("Nearest", None, self)]