diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
deleted file mode 100644
index 50c54210..00000000
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ /dev/null
@@ -1,32 +0,0 @@
----
-name: Bug report
-about: Create a report to help us improve
-title: ''
-labels: bug-report
-assignees: ''
-
----
-
-**Describe the bug**
-A clear and concise description of what the bug is.
-
-**To Reproduce**
-Steps to reproduce the behavior:
-1. Go to '...'
-2. Click on '....'
-3. Scroll down to '....'
-4. See error
-
-**Expected behavior**
-A clear and concise description of what you expected to happen.
-
-**Screenshots**
-If applicable, add screenshots to help explain your problem.
-
-**Desktop (please complete the following information):**
- - OS: [e.g. Windows, Linux]
- - Browser [e.g. chrome, safari]
- - Commit revision [looks like this: e68484500f76a33ba477d5a99340ab30451e557b; can be seen when launching webui.bat, or obtained manually by running `git rev-parse HEAD`]
-
-**Additional context**
-Add any other context about the problem here.
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 00000000..ed372f22
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,83 @@
+name: Bug Report
+description: You think somethings is broken in the UI
+title: "[Bug]: "
+labels: ["bug-report"]
+
+body:
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
+      options:
+        - label: I have searched the existing issues and checked the recent builds/commits
+          required: true
+  - type: markdown
+    attributes:
+      value: |
+        *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
+  - type: textarea
+    id: what-did
+    attributes:
+      label: What happened?
+      description: Tell us what happened in a very clear and simple way
+    validations:
+      required: true
+  - type: textarea
+    id: steps
+    attributes:
+      label: Steps to reproduce the problem
+      description: Please provide us with precise step by step information on how to reproduce the bug
+      value: |
+        1. Go to .... 
+        2. Press ....
+        3. ...
+    validations:
+      required: true
+  - type: textarea
+    id: what-should
+    attributes:
+      label: What should have happened?
+      description: tell what you think the normal behavior should be
+    validations:
+      required: true
+  - type: input
+    id: commit
+    attributes:
+      label: Commit where the problem happens
+      description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+    validations:
+      required: true
+  - type: dropdown
+    id: platforms
+    attributes:
+      label: What platforms do you use to access UI ?
+      multiple: true
+      options:
+        - Windows
+        - Linux
+        - MacOS
+        - iOS
+        - Android
+        - Other/Cloud
+  - type: dropdown
+    id: browsers
+    attributes:
+      label: What browsers do you use to access the UI ?
+      multiple: true
+      options:
+        - Mozilla Firefox
+        - Google Chrome
+        - Brave
+        - Apple Safari
+        - Microsoft Edge
+  - type: textarea
+    id: cmdargs
+    attributes:
+      label: Command Line Arguments
+      description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
+      render: Shell
+  - type: textarea
+    id: misc
+    attributes:
+      label: Additional information, context and logs
+      description: Please provide us with any relevant additional info, context or log output.
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 00000000..f58c94a9
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,5 @@
+blank_issues_enabled: false
+contact_links:
+  - name: WebUI Community Support
+    url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
+    about: Please ask and answer questions here.
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
deleted file mode 100644
index eda42fa7..00000000
--- a/.github/ISSUE_TEMPLATE/feature_request.md
+++ /dev/null
@@ -1,20 +0,0 @@
----
-name: Feature request
-about: Suggest an idea for this project
-title: ''
-labels: 'suggestion'
-assignees: ''
-
----
-
-**Is your feature request related to a problem? Please describe.**
-A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
-
-**Describe the solution you'd like**
-A clear and concise description of what you want to happen.
-
-**Describe alternatives you've considered**
-A clear and concise description of any alternative solutions or features you've considered.
-
-**Additional context**
-Add any other context or screenshots about the feature request here.
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
new file mode 100644
index 00000000..8ca6e21f
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -0,0 +1,40 @@
+name: Feature request
+description: Suggest an idea for this project
+title: "[Feature Request]: "
+labels: ["suggestion"]
+
+body:
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
+      options:
+        - label: I have searched the existing issues and checked the recent builds/commits
+          required: true
+  - type: markdown
+    attributes:
+      value: |
+        *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
+  - type: textarea
+    id: feature
+    attributes:
+      label: What would your feature do ?
+      description: Tell us about your feature in a very clear and simple way, and what problem it would solve
+    validations:
+      required: true
+  - type: textarea
+    id: workflow
+    attributes:
+      label: Proposed workflow
+      description: Please provide us with step by step information on how you'd like the feature to be accessed and used
+      value: |
+        1. Go to .... 
+        2. Press ....
+        3. ...
+    validations:
+      required: true
+  - type: textarea
+    id: misc
+    attributes:
+      label: Additional information
+      description: Add any other context or screenshots about the feature request here.
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
new file mode 100644
index 00000000..49dc92bd
--- /dev/null
+++ b/.github/workflows/run_tests.yaml
@@ -0,0 +1,31 @@
+name: Run basic features tests on CPU with empty SD model
+
+on:
+  - push
+  - pull_request
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v3
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.10.6
+      - uses: actions/cache@v3
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
+          restore-keys: ${{ runner.os }}-pip-
+      - name: Run tests
+        run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
+      - name: Upload main app stdout-stderr
+        uses: actions/upload-artifact@v3
+        if: always()
+        with:
+          name: stdout-stderr
+          path: |
+            test/stdout.txt
+            test/stderr.txt
diff --git a/.gitignore b/.gitignore
index 69785b3e..21fa26a7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 __pycache__
 *.ckpt
+*.safetensors
 *.pth
 /ESRGAN/*
 /SwinIR/*
@@ -27,3 +28,7 @@ __pycache__
 notification.mp3
 /SwinIR
 /textual_inversion
+.vscode
+/extensions
+/test/stdout.txt
+/test/stderr.txt
diff --git a/CODEOWNERS b/CODEOWNERS
index 935fedcf..7438c9bc 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1 +1,12 @@
 *       @AUTOMATIC1111
+
+# if you were managing a localization and were removed from this file, this is because
+# the intended way to do localizations now is via extensions. See:
+# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
+# Make a repo with your localization and since you are still listed as a collaborator
+# you can add it to the wiki page yourself. This change is because some people complained
+# the git commit log is cluttered with things unrelated to almost everyone and
+# because I believe this is the best overall for the project to handle localizations almost
+# entirely without my oversight.
+
+
diff --git a/README.md b/README.md
index 859a91b6..88250a6b 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - One click install and run script (but you still must install python and git)
 - Outpainting
 - Inpainting
+- Color Sketch
 - Prompt Matrix
 - Stable Diffusion Upscale
 - Attention, specify parts of text that the model should pay more attention to
@@ -23,6 +24,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
     - have as many embeddings as you want and use any names you like for them
     - use multiple embeddings with different numbers of vectors per token
     - works with half precision floating point numbers
+    - train embeddings on 8GB (also reports of 6GB working)
 - Extras tab with:
     - GFPGAN, neural network that fixes faces
     - CodeFormer, face restoration tool as an alternative to GFPGAN
@@ -37,14 +39,14 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - Interrupt processing at any time
 - 4GB video card support (also reports of 2GB working)
 - Correct seeds for batches
-- Prompt length validation
-     - get length of prompt in tokens as you type
-     - get a warning after generation if some text was truncated
+- Live prompt token length validation
 - Generation parameters
      - parameters you used to generate images are saved with that image
      - in PNG chunks for PNG, in EXIF for JPEG
      - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
      - can be disabled in settings
+     - drag and drop an image/text-parameters to promptbox
+- Read Generation Parameters Button, loads parameters in promptbox to UI
 - Settings page
 - Running arbitrary python code from UI (must run with --allow-code to enable)
 - Mouseover hints for most UI elements
@@ -59,25 +61,37 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - CLIP interrogator, a button that tries to guess prompt from an image
 - Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
 - Batch Processing, process a group of files using img2img
-- Img2img Alternative
+- Img2img Alternative, reverse Euler method of cross attention control
 - Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
 - Reloading checkpoints on the fly
-- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
+- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
 - [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
 - [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
      - separate prompts using uppercase `AND`
      - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
 - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
-- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
+- DeepDanbooru integration, creates danbooru style tags for anime prompts
 - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
+- via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
+- Generate forever option
+- Training tab
+     - hypernetworks and embeddings options
+     - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
+- Clip skip
+- Use Hypernetworks
+- Use VAEs
+- Estimated completion time in progress bar
+- API
+- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
+- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
+- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
 
 ## Installation and Running
 Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
 
-Alternatively, use Google Colab:
+Alternatively, use online services (like Google Colab):
 
-- [Colab, maintained by Akaibu](https://colab.research.google.com/drive/1kw3egmSn-KgWsikYvOMjJkVDsPLjEMzl)
-- [Colab, original by me, outdated](https://colab.research.google.com/drive/1Iy-xW9t1-OQWhb0hNxueGij8phCyluOh).
+- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
 
 ### Automatic Installation on Windows
 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
@@ -113,6 +127,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC
 The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
 
 ## Credits
+Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
+
 - Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
 - GFPGAN - https://github.com/TencentARC/GFPGAN.git
@@ -121,15 +137,17 @@ The documentation was moved from this README over to the project's [wiki](https:
 - SwinIR - https://github.com/JingyunLiang/SwinIR
 - Swin2SR - https://github.com/mv-lab/swin2sr
 - LDSR - https://github.com/Hafiidz/latent-diffusion
+- MiDaS - https://github.com/isl-org/MiDaS
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
-- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
-- InvokeAI, lstein - Cross Attention layer optimization - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
-- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
+- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
+- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
+- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
 - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
 - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
 - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
 - Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
 - xformers - https://github.com/facebookresearch/xformers
 - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
+- Security advice - RyotaK
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
diff --git a/configs/alt-diffusion-inference.yaml b/configs/alt-diffusion-inference.yaml
new file mode 100644
index 00000000..cfbee72d
--- /dev/null
+++ b/configs/alt-diffusion-inference.yaml
@@ -0,0 +1,72 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: modules.xlmr.BertSeriesModelWithTransformation
+      params:
+        name: "XLMR-Large"
\ No newline at end of file
diff --git a/configs/v1-inference.yaml b/configs/v1-inference.yaml
new file mode 100644
index 00000000..d4effe56
--- /dev/null
+++ b/configs/v1-inference.yaml
@@ -0,0 +1,70 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
diff --git a/modules/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
similarity index 77%
rename from modules/ldsr_model_arch.py
rename to extensions-builtin/LDSR/ldsr_model_arch.py
index 14db5076..0ad49f4e 100644
--- a/modules/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -1,3 +1,4 @@
+import os
 import gc
 import time
 import warnings
@@ -8,27 +9,49 @@ import torchvision
 from PIL import Image
 from einops import rearrange, repeat
 from omegaconf import OmegaConf
+import safetensors.torch
 
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config, ismap
+from modules import shared, sd_hijack
 
 warnings.filterwarnings("ignore", category=UserWarning)
 
+cached_ldsr_model: torch.nn.Module = None
+
 
 # Create LDSR Class
 class LDSR:
     def load_model_from_config(self, half_attention):
-        print(f"Loading model from {self.modelPath}")
-        pl_sd = torch.load(self.modelPath, map_location="cpu")
-        sd = pl_sd["state_dict"]
-        config = OmegaConf.load(self.yamlPath)
-        model = instantiate_from_config(config.model)
-        model.load_state_dict(sd, strict=False)
-        model.cuda()
-        if half_attention:
-            model = model.half()
+        global cached_ldsr_model
+
+        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
+            print("Loading model from cache")
+            model: torch.nn.Module = cached_ldsr_model
+        else:
+            print(f"Loading model from {self.modelPath}")
+            _, extension = os.path.splitext(self.modelPath)
+            if extension.lower() == ".safetensors":
+                pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
+            else:
+                pl_sd = torch.load(self.modelPath, map_location="cpu")
+            sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
+            config = OmegaConf.load(self.yamlPath)
+            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
+            model: torch.nn.Module = instantiate_from_config(config.model)
+            model.load_state_dict(sd, strict=False)
+            model = model.to(shared.device)
+            if half_attention:
+                model = model.half()
+            if shared.cmd_opts.opt_channelslast:
+                model = model.to(memory_format=torch.channels_last)
+
+            sd_hijack.model_hijack.hijack(model) # apply optimization
+            model.eval()
+
+            if shared.opts.ldsr_cached:
+                cached_ldsr_model = model
 
-        model.eval()
         return {"model": model}
 
     def __init__(self, model_path, yaml_path):
@@ -93,7 +116,8 @@ class LDSR:
         down_sample_method = 'Lanczos'
 
         gc.collect()
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available:
+            torch.cuda.empty_cache()
 
         im_og = image
         width_og, height_og = im_og.size
@@ -101,8 +125,8 @@ class LDSR:
         down_sample_rate = target_scale / 4
         wd = width_og * down_sample_rate
         hd = height_og * down_sample_rate
-        width_downsampled_pre = int(wd)
-        height_downsampled_pre = int(hd)
+        width_downsampled_pre = int(np.ceil(wd))
+        height_downsampled_pre = int(np.ceil(hd))
 
         if down_sample_rate != 1:
             print(
@@ -110,7 +134,12 @@ class LDSR:
             im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
         else:
             print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
-        logs = self.run(model["model"], im_og, diffusion_steps, eta)
+        
+        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
+        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
+        im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
+        
+        logs = self.run(model["model"], im_padded, diffusion_steps, eta)
 
         sample = logs["sample"]
         sample = sample.detach().cpu()
@@ -120,9 +149,14 @@ class LDSR:
         sample = np.transpose(sample, (0, 2, 3, 1))
         a = Image.fromarray(sample[0])
 
+        # remove padding
+        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
+
         del model
         gc.collect()
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available:
+            torch.cuda.empty_cache()
+
         return a
 
 
@@ -137,7 +171,7 @@ def get_cond(selected_path):
     c = rearrange(c, '1 c h w -> 1 h w c')
     c = 2. * c - 1.
 
-    c = c.to(torch.device("cuda"))
+    c = c.to(shared.device)
     example["LR_image"] = c
     example["image"] = c_up
 
diff --git a/extensions-builtin/LDSR/preload.py b/extensions-builtin/LDSR/preload.py
new file mode 100644
index 00000000..d746007c
--- /dev/null
+++ b/extensions-builtin/LDSR/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+    parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
diff --git a/modules/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
similarity index 66%
rename from modules/ldsr_model.py
rename to extensions-builtin/LDSR/scripts/ldsr_model.py
index 8c4db44a..b8cff29b 100644
--- a/modules/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -5,8 +5,9 @@ import traceback
 from basicsr.utils.download_util import load_file_from_url
 
 from modules.upscaler import Upscaler, UpscalerData
-from modules.ldsr_model_arch import LDSR
-from modules import shared
+from ldsr_model_arch import LDSR
+from modules import shared, script_callbacks
+import sd_hijack_autoencoder, sd_hijack_ddpm_v1
 
 
 class UpscalerLDSR(Upscaler):
@@ -24,6 +25,7 @@ class UpscalerLDSR(Upscaler):
         yaml_path = os.path.join(self.model_path, "project.yaml")
         old_model_path = os.path.join(self.model_path, "model.pth")
         new_model_path = os.path.join(self.model_path, "model.ckpt")
+        safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
         if os.path.exists(yaml_path):
             statinfo = os.stat(yaml_path)
             if statinfo.st_size >= 10485760:
@@ -32,8 +34,11 @@ class UpscalerLDSR(Upscaler):
         if os.path.exists(old_model_path):
             print("Renaming model from model.pth to model.ckpt")
             os.rename(old_model_path, new_model_path)
-        model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
-                                   file_name="model.ckpt", progress=True)
+        if os.path.exists(safetensors_model_path):
+            model = safetensors_model_path
+        else:
+            model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+                                       file_name="model.ckpt", progress=True)
         yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
                                   file_name="project.yaml", progress=True)
 
@@ -52,3 +57,13 @@ class UpscalerLDSR(Upscaler):
             return img
         ddim_steps = shared.opts.ldsr_steps
         return ldsr.super_resolution(img, ddim_steps, self.scale)
+
+
+def on_ui_settings():
+    import gradio as gr
+
+    shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
+    shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
new file mode 100644
index 00000000..8e03c7f8
--- /dev/null
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -0,0 +1,286 @@
+# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
+# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
+# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
+
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.util import instantiate_from_config
+
+import ldm.models.autoencoder
+
+class VQModel(pl.LightningModule):
+    def __init__(self,
+                 ddconfig,
+                 lossconfig,
+                 n_embed,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 image_key="image",
+                 colorize_nlabels=None,
+                 monitor=None,
+                 batch_resize_range=None,
+                 scheduler_config=None,
+                 lr_g_factor=1.0,
+                 remap=None,
+                 sane_index_shape=False, # tell vector quantizer to return indices as bhw
+                 use_ema=False
+                 ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.n_embed = n_embed
+        self.image_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+                                        remap=remap,
+                                        sane_index_shape=sane_index_shape)
+        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels)==int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+        self.batch_resize_range = batch_resize_range
+        if self.batch_resize_range is not None:
+            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
+
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.model_ema = LitEma(self)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+        self.scheduler_config = scheduler_config
+        self.lr_g_factor = lr_g_factor
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.parameters())
+            self.model_ema.copy_to(self)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        missing, unexpected = self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+            print(f"Unexpected Keys: {unexpected}")
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self)
+
+    def encode(self, x):
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+        quant, emb_loss, info = self.quantize(h)
+        return quant, emb_loss, info
+
+    def encode_to_prequant(self, x):
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+        return h
+
+    def decode(self, quant):
+        quant = self.post_quant_conv(quant)
+        dec = self.decoder(quant)
+        return dec
+
+    def decode_code(self, code_b):
+        quant_b = self.quantize.embed_code(code_b)
+        dec = self.decode(quant_b)
+        return dec
+
+    def forward(self, input, return_pred_indices=False):
+        quant, diff, (_,_,ind) = self.encode(input)
+        dec = self.decode(quant)
+        if return_pred_indices:
+            return dec, diff, ind
+        return dec, diff
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+        if self.batch_resize_range is not None:
+            lower_size = self.batch_resize_range[0]
+            upper_size = self.batch_resize_range[1]
+            if self.global_step <= 4:
+                # do the first few batches with max size to avoid later oom
+                new_resize = upper_size
+            else:
+                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
+            if new_resize != x.shape[2]:
+                x = F.interpolate(x, size=new_resize, mode="bicubic")
+            x = x.detach()
+        return x
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        # https://github.com/pytorch/pytorch/issues/37142
+        # try not to fool the heuristics
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss, ind = self(x, return_pred_indices=True)
+
+        if optimizer_idx == 0:
+            # autoencode
+            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train",
+                                            predicted_indices=ind)
+
+            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            return aeloss
+
+        if optimizer_idx == 1:
+            # discriminator
+            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train")
+            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        log_dict = self._validation_step(batch, batch_idx)
+        with self.ema_scope():
+            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+        return log_dict
+
+    def _validation_step(self, batch, batch_idx, suffix=""):
+        x = self.get_input(batch, self.image_key)
+        xrec, qloss, ind = self(x, return_pred_indices=True)
+        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
+                                        self.global_step,
+                                        last_layer=self.get_last_layer(),
+                                        split="val"+suffix,
+                                        predicted_indices=ind
+                                        )
+
+        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
+                                            self.global_step,
+                                            last_layer=self.get_last_layer(),
+                                            split="val"+suffix,
+                                            predicted_indices=ind
+                                            )
+        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
+        self.log(f"val{suffix}/rec_loss", rec_loss,
+                   prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+        self.log(f"val{suffix}/aeloss", aeloss,
+                   prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+        if version.parse(pl.__version__) >= version.parse('1.4.0'):
+            del log_dict_ae[f"val{suffix}/rec_loss"]
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        return self.log_dict
+
+    def configure_optimizers(self):
+        lr_d = self.learning_rate
+        lr_g = self.lr_g_factor*self.learning_rate
+        print("lr_d", lr_d)
+        print("lr_g", lr_g)
+        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+                                  list(self.decoder.parameters())+
+                                  list(self.quantize.parameters())+
+                                  list(self.quant_conv.parameters())+
+                                  list(self.post_quant_conv.parameters()),
+                                  lr=lr_g, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                    lr=lr_d, betas=(0.5, 0.9))
+
+        if self.scheduler_config is not None:
+            scheduler = instantiate_from_config(self.scheduler_config)
+
+            print("Setting up LambdaLR scheduler...")
+            scheduler = [
+                {
+                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
+                    'interval': 'step',
+                    'frequency': 1
+                },
+                {
+                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
+                    'interval': 'step',
+                    'frequency': 1
+                },
+            ]
+            return [opt_ae, opt_disc], scheduler
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        if only_inputs:
+            log["inputs"] = x
+            return log
+        xrec, _ = self(x)
+        if x.shape[1] > 3:
+            # colorize with random projection
+            assert xrec.shape[1] > 3
+            x = self.to_rgb(x)
+            xrec = self.to_rgb(xrec)
+        log["inputs"] = x
+        log["reconstructions"] = xrec
+        if plot_ema:
+            with self.ema_scope():
+                xrec_ema, _ = self(x)
+                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+                log["reconstructions_ema"] = xrec_ema
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+        return x
+
+
+class VQModelInterface(VQModel):
+    def __init__(self, embed_dim, *args, **kwargs):
+        super().__init__(embed_dim=embed_dim, *args, **kwargs)
+        self.embed_dim = embed_dim
+
+    def encode(self, x):
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+        return h
+
+    def decode(self, h, force_not_quantize=False):
+        # also go through quantization layer
+        if not force_not_quantize:
+            quant, emb_loss, info = self.quantize(h)
+        else:
+            quant = h
+        quant = self.post_quant_conv(quant)
+        dec = self.decoder(quant)
+        return dec
+
+setattr(ldm.models.autoencoder, "VQModel", VQModel)
+setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
new file mode 100644
index 00000000..5c0488e5
--- /dev/null
+++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
@@ -0,0 +1,1449 @@
+# This script is copied from the compvis/stable-diffusion repo (aka the SD V1 repo)
+# Original filename: ldm/models/diffusion/ddpm.py
+# The purpose to reinstate the old DDPM logic which works with VQ, whereas the V2 one doesn't
+# Some models such as LDSR require VQ to work correctly
+# The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager
+from functools import partial
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+import ldm.models.diffusion.ddpm
+
+__conditioning_keys__ = {'concat': 'c_concat',
+                         'crossattn': 'c_crossattn',
+                         'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+    return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPMV1(pl.LightningModule):
+    # classic DDPM with Gaussian diffusion, in image space
+    def __init__(self,
+                 unet_config,
+                 timesteps=1000,
+                 beta_schedule="linear",
+                 loss_type="l2",
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 load_only_unet=False,
+                 monitor="val/loss",
+                 use_ema=True,
+                 first_stage_key="image",
+                 image_size=256,
+                 channels=3,
+                 log_every_t=100,
+                 clip_denoised=True,
+                 linear_start=1e-4,
+                 linear_end=2e-2,
+                 cosine_s=8e-3,
+                 given_betas=None,
+                 original_elbo_weight=0.,
+                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+                 l_simple_weight=1.,
+                 conditioning_key=None,
+                 parameterization="eps",  # all assuming fixed variance schedules
+                 scheduler_config=None,
+                 use_positional_encodings=False,
+                 learn_logvar=False,
+                 logvar_init=0.,
+                 ):
+        super().__init__()
+        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
+        self.parameterization = parameterization
+        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+        self.cond_stage_model = None
+        self.clip_denoised = clip_denoised
+        self.log_every_t = log_every_t
+        self.first_stage_key = first_stage_key
+        self.image_size = image_size  # try conv?
+        self.channels = channels
+        self.use_positional_encodings = use_positional_encodings
+        self.model = DiffusionWrapperV1(unet_config, conditioning_key)
+        count_params(self.model, verbose=True)
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.model_ema = LitEma(self.model)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        self.use_scheduler = scheduler_config is not None
+        if self.use_scheduler:
+            self.scheduler_config = scheduler_config
+
+        self.v_posterior = v_posterior
+        self.original_elbo_weight = original_elbo_weight
+        self.l_simple_weight = l_simple_weight
+
+        if monitor is not None:
+            self.monitor = monitor
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+
+        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+        self.loss_type = loss_type
+
+        self.learn_logvar = learn_logvar
+        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+        if self.learn_logvar:
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+
+    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        if exists(given_betas):
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                       cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+                    1. - alphas_cumprod) + self.v_posterior * betas
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+        if self.parameterization == "eps":
+            lvlb_weights = self.betas ** 2 / (
+                        2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+        elif self.parameterization == "x0":
+            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+        else:
+            raise NotImplementedError("mu not supported")
+        # TODO how to choose this term
+        lvlb_weights[0] = lvlb_weights[1]
+        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+        assert not torch.isnan(self.lvlb_weights).all()
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.model.parameters())
+            self.model_ema.copy_to(self.model)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.model.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+            sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+        if len(unexpected) > 0:
+            print(f"Unexpected Keys: {unexpected}")
+
+    def q_mean_variance(self, x_start, t):
+        """
+        Get the distribution q(x_t | x_0).
+        :param x_start: the [N x C x ...] tensor of noiseless inputs.
+        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+        """
+        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, clip_denoised: bool):
+        model_out = self.model(x, t)
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, return_intermediates=False):
+        device = self.betas.device
+        b = shape[0]
+        img = torch.randn(shape, device=device)
+        intermediates = [img]
+        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+                                clip_denoised=self.clip_denoised)
+            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+                intermediates.append(img)
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, batch_size=16, return_intermediates=False):
+        image_size = self.image_size
+        channels = self.channels
+        return self.p_sample_loop((batch_size, channels, image_size, image_size),
+                                  return_intermediates=return_intermediates)
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def get_loss(self, pred, target, mean=True):
+        if self.loss_type == 'l1':
+            loss = (target - pred).abs()
+            if mean:
+                loss = loss.mean()
+        elif self.loss_type == 'l2':
+            if mean:
+                loss = torch.nn.functional.mse_loss(target, pred)
+            else:
+                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+        else:
+            raise NotImplementedError("unknown loss type '{loss_type}'")
+
+        return loss
+
+    def p_losses(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_out = self.model(x_noisy, t)
+
+        loss_dict = {}
+        if self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "x0":
+            target = x_start
+        else:
+            raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
+
+        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+        log_prefix = 'train' if self.training else 'val'
+
+        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+        loss_simple = loss.mean() * self.l_simple_weight
+
+        loss_vlb = (self.lvlb_weights[t] * loss).mean()
+        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+        loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+        loss_dict.update({f'{log_prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    def forward(self, x, *args, **kwargs):
+        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        return self.p_losses(x, t, *args, **kwargs)
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = rearrange(x, 'b h w c -> b c h w')
+        x = x.to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def shared_step(self, batch):
+        x = self.get_input(batch, self.first_stage_key)
+        loss, loss_dict = self(x)
+        return loss, loss_dict
+
+    def training_step(self, batch, batch_idx):
+        loss, loss_dict = self.shared_step(batch)
+
+        self.log_dict(loss_dict, prog_bar=True,
+                      logger=True, on_step=True, on_epoch=True)
+
+        self.log("global_step", self.global_step,
+                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        if self.use_scheduler:
+            lr = self.optimizers().param_groups[0]['lr']
+            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        return loss
+
+    @torch.no_grad()
+    def validation_step(self, batch, batch_idx):
+        _, loss_dict_no_ema = self.shared_step(batch)
+        with self.ema_scope():
+            _, loss_dict_ema = self.shared_step(batch)
+            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self.model)
+
+    def _get_rows_from_list(self, samples):
+        n_imgs_per_row = len(samples)
+        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.first_stage_key)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        x = x.to(self.device)[:N]
+        log["inputs"] = x
+
+        # get diffusion row
+        diffusion_row = list()
+        x_start = x[:n_row]
+
+        for t in range(self.num_timesteps):
+            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                t = t.to(self.device).long()
+                noise = torch.randn_like(x_start)
+                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+                diffusion_row.append(x_noisy)
+
+        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+        if sample:
+            # get denoise row
+            with self.ema_scope("Plotting"):
+                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+            log["samples"] = samples
+            log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.learn_logvar:
+            params = params + [self.logvar]
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+
+class LatentDiffusionV1(DDPMV1):
+    """main class"""
+    def __init__(self,
+                 first_stage_config,
+                 cond_stage_config,
+                 num_timesteps_cond=None,
+                 cond_stage_key="image",
+                 cond_stage_trainable=False,
+                 concat_mode=True,
+                 cond_stage_forward=None,
+                 conditioning_key=None,
+                 scale_factor=1.0,
+                 scale_by_std=False,
+                 *args, **kwargs):
+        self.num_timesteps_cond = default(num_timesteps_cond, 1)
+        self.scale_by_std = scale_by_std
+        assert self.num_timesteps_cond <= kwargs['timesteps']
+        # for backwards compatibility after implementation of DiffusionWrapper
+        if conditioning_key is None:
+            conditioning_key = 'concat' if concat_mode else 'crossattn'
+        if cond_stage_config == '__is_unconditional__':
+            conditioning_key = None
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        ignore_keys = kwargs.pop("ignore_keys", [])
+        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        self.concat_mode = concat_mode
+        self.cond_stage_trainable = cond_stage_trainable
+        self.cond_stage_key = cond_stage_key
+        try:
+            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except:
+            self.num_downs = 0
+        if not scale_by_std:
+            self.scale_factor = scale_factor
+        else:
+            self.register_buffer('scale_factor', torch.tensor(scale_factor))
+        self.instantiate_first_stage(first_stage_config)
+        self.instantiate_cond_stage(cond_stage_config)
+        self.cond_stage_forward = cond_stage_forward
+        self.clip_denoised = False
+        self.bbox_tokenizer = None  
+
+        self.restarted_from_ckpt = False
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+            self.restarted_from_ckpt = True
+
+    def make_cond_schedule(self, ):
+        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+        self.cond_ids[:self.num_timesteps_cond] = ids
+
+    @rank_zero_only
+    @torch.no_grad()
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        # only for very first batch
+        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+            # set rescale weight to 1./std of encodings
+            print("### USING STD-RESCALING ###")
+            x = super().get_input(batch, self.first_stage_key)
+            x = x.to(self.device)
+            encoder_posterior = self.encode_first_stage(x)
+            z = self.get_first_stage_encoding(encoder_posterior).detach()
+            del self.scale_factor
+            self.register_buffer('scale_factor', 1. / z.flatten().std())
+            print(f"setting self.scale_factor to {self.scale_factor}")
+            print("### USING STD-RESCALING ###")
+
+    def register_schedule(self,
+                          given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+        self.shorten_cond_schedule = self.num_timesteps_cond > 1
+        if self.shorten_cond_schedule:
+            self.make_cond_schedule()
+
+    def instantiate_first_stage(self, config):
+        model = instantiate_from_config(config)
+        self.first_stage_model = model.eval()
+        self.first_stage_model.train = disabled_train
+        for param in self.first_stage_model.parameters():
+            param.requires_grad = False
+
+    def instantiate_cond_stage(self, config):
+        if not self.cond_stage_trainable:
+            if config == "__is_first_stage__":
+                print("Using first stage also as cond stage.")
+                self.cond_stage_model = self.first_stage_model
+            elif config == "__is_unconditional__":
+                print(f"Training {self.__class__.__name__} as an unconditional model.")
+                self.cond_stage_model = None
+                # self.be_unconditional = True
+            else:
+                model = instantiate_from_config(config)
+                self.cond_stage_model = model.eval()
+                self.cond_stage_model.train = disabled_train
+                for param in self.cond_stage_model.parameters():
+                    param.requires_grad = False
+        else:
+            assert config != '__is_first_stage__'
+            assert config != '__is_unconditional__'
+            model = instantiate_from_config(config)
+            self.cond_stage_model = model
+
+    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+        denoise_row = []
+        for zd in tqdm(samples, desc=desc):
+            denoise_row.append(self.decode_first_stage(zd.to(self.device),
+                                                            force_not_quantize=force_no_decoder_quantization))
+        n_imgs_per_row = len(denoise_row)
+        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
+        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    def get_first_stage_encoding(self, encoder_posterior):
+        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+            z = encoder_posterior.sample()
+        elif isinstance(encoder_posterior, torch.Tensor):
+            z = encoder_posterior
+        else:
+            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+        return self.scale_factor * z
+
+    def get_learned_conditioning(self, c):
+        if self.cond_stage_forward is None:
+            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+                c = self.cond_stage_model.encode(c)
+                if isinstance(c, DiagonalGaussianDistribution):
+                    c = c.mode()
+            else:
+                c = self.cond_stage_model(c)
+        else:
+            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+        return c
+
+    def meshgrid(self, h, w):
+        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+        arr = torch.cat([y, x], dim=-1)
+        return arr
+
+    def delta_border(self, h, w):
+        """
+        :param h: height
+        :param w: width
+        :return: normalized distance to image border,
+         wtith min distance = 0 at border and max dist = 0.5 at image center
+        """
+        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+        arr = self.meshgrid(h, w) / lower_right_corner
+        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+        return edge_dist
+
+    def get_weighting(self, h, w, Ly, Lx, device):
+        weighting = self.delta_border(h, w)
+        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+                               self.split_input_params["clip_max_weight"], )
+        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+        if self.split_input_params["tie_braker"]:
+            L_weighting = self.delta_border(Ly, Lx)
+            L_weighting = torch.clip(L_weighting,
+                                     self.split_input_params["clip_min_tie_weight"],
+                                     self.split_input_params["clip_max_tie_weight"])
+
+            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+            weighting = weighting * L_weighting
+        return weighting
+
+    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
+        """
+        :param x: img of size (bs, c, h, w)
+        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+        """
+        bs, nc, h, w = x.shape
+
+        # number of crops in image
+        Ly = (h - kernel_size[0]) // stride[0] + 1
+        Lx = (w - kernel_size[1]) // stride[1] + 1
+
+        if uf == 1 and df == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+        elif uf > 1 and df == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+                                dilation=1, padding=0,
+                                stride=(stride[0] * uf, stride[1] * uf))
+            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+        elif df > 1 and uf == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+                                dilation=1, padding=0,
+                                stride=(stride[0] // df, stride[1] // df))
+            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+        else:
+            raise NotImplementedError
+
+        return fold, unfold, normalization, weighting
+
+    @torch.no_grad()
+    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+                  cond_key=None, return_original_cond=False, bs=None):
+        x = super().get_input(batch, k)
+        if bs is not None:
+            x = x[:bs]
+        x = x.to(self.device)
+        encoder_posterior = self.encode_first_stage(x)
+        z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+        if self.model.conditioning_key is not None:
+            if cond_key is None:
+                cond_key = self.cond_stage_key
+            if cond_key != self.first_stage_key:
+                if cond_key in ['caption', 'coordinates_bbox']:
+                    xc = batch[cond_key]
+                elif cond_key == 'class_label':
+                    xc = batch
+                else:
+                    xc = super().get_input(batch, cond_key).to(self.device)
+            else:
+                xc = x
+            if not self.cond_stage_trainable or force_c_encode:
+                if isinstance(xc, dict) or isinstance(xc, list):
+                    # import pudb; pudb.set_trace()
+                    c = self.get_learned_conditioning(xc)
+                else:
+                    c = self.get_learned_conditioning(xc.to(self.device))
+            else:
+                c = xc
+            if bs is not None:
+                c = c[:bs]
+
+            if self.use_positional_encodings:
+                pos_x, pos_y = self.compute_latent_shifts(batch)
+                ckey = __conditioning_keys__[self.model.conditioning_key]
+                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+        else:
+            c = None
+            xc = None
+            if self.use_positional_encodings:
+                pos_x, pos_y = self.compute_latent_shifts(batch)
+                c = {'pos_x': pos_x, 'pos_y': pos_y}
+        out = [z, c]
+        if return_first_stage_outputs:
+            xrec = self.decode_first_stage(z)
+            out.extend([x, xrec])
+        if return_original_cond:
+            out.append(xc)
+        return out
+
+    @torch.no_grad()
+    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+        if predict_cids:
+            if z.dim() == 4:
+                z = torch.argmax(z.exp(), dim=1).long()
+            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+            z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+        z = 1. / self.scale_factor * z
+
+        if hasattr(self, "split_input_params"):
+            if self.split_input_params["patch_distributed_vq"]:
+                ks = self.split_input_params["ks"]  # eg. (128, 128)
+                stride = self.split_input_params["stride"]  # eg. (64, 64)
+                uf = self.split_input_params["vqf"]
+                bs, nc, h, w = z.shape
+                if ks[0] > h or ks[1] > w:
+                    ks = (min(ks[0], h), min(ks[1], w))
+                    print("reducing Kernel")
+
+                if stride[0] > h or stride[1] > w:
+                    stride = (min(stride[0], h), min(stride[1], w))
+                    print("reducing stride")
+
+                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+                z = unfold(z)  # (bn, nc * prod(**ks), L)
+                # 1. Reshape to img shape
+                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+
+                # 2. apply model loop over last dim
+                if isinstance(self.first_stage_model, VQModelInterface):
+                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+                                                                 force_not_quantize=predict_cids or force_not_quantize)
+                                   for i in range(z.shape[-1])]
+                else:
+
+                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+                                   for i in range(z.shape[-1])]
+
+                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)
+                o = o * weighting
+                # Reverse 1. reshape to img shape
+                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
+                # stitch crops together
+                decoded = fold(o)
+                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
+                return decoded
+            else:
+                if isinstance(self.first_stage_model, VQModelInterface):
+                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+                else:
+                    return self.first_stage_model.decode(z)
+
+        else:
+            if isinstance(self.first_stage_model, VQModelInterface):
+                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+            else:
+                return self.first_stage_model.decode(z)
+
+    # same as above but without decorator
+    def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+        if predict_cids:
+            if z.dim() == 4:
+                z = torch.argmax(z.exp(), dim=1).long()
+            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+            z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+        z = 1. / self.scale_factor * z
+
+        if hasattr(self, "split_input_params"):
+            if self.split_input_params["patch_distributed_vq"]:
+                ks = self.split_input_params["ks"]  # eg. (128, 128)
+                stride = self.split_input_params["stride"]  # eg. (64, 64)
+                uf = self.split_input_params["vqf"]
+                bs, nc, h, w = z.shape
+                if ks[0] > h or ks[1] > w:
+                    ks = (min(ks[0], h), min(ks[1], w))
+                    print("reducing Kernel")
+
+                if stride[0] > h or stride[1] > w:
+                    stride = (min(stride[0], h), min(stride[1], w))
+                    print("reducing stride")
+
+                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+                z = unfold(z)  # (bn, nc * prod(**ks), L)
+                # 1. Reshape to img shape
+                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+
+                # 2. apply model loop over last dim
+                if isinstance(self.first_stage_model, VQModelInterface):  
+                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+                                                                 force_not_quantize=predict_cids or force_not_quantize)
+                                   for i in range(z.shape[-1])]
+                else:
+
+                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+                                   for i in range(z.shape[-1])]
+
+                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)
+                o = o * weighting
+                # Reverse 1. reshape to img shape
+                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
+                # stitch crops together
+                decoded = fold(o)
+                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
+                return decoded
+            else:
+                if isinstance(self.first_stage_model, VQModelInterface):
+                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+                else:
+                    return self.first_stage_model.decode(z)
+
+        else:
+            if isinstance(self.first_stage_model, VQModelInterface):
+                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+            else:
+                return self.first_stage_model.decode(z)
+
+    @torch.no_grad()
+    def encode_first_stage(self, x):
+        if hasattr(self, "split_input_params"):
+            if self.split_input_params["patch_distributed_vq"]:
+                ks = self.split_input_params["ks"]  # eg. (128, 128)
+                stride = self.split_input_params["stride"]  # eg. (64, 64)
+                df = self.split_input_params["vqf"]
+                self.split_input_params['original_image_size'] = x.shape[-2:]
+                bs, nc, h, w = x.shape
+                if ks[0] > h or ks[1] > w:
+                    ks = (min(ks[0], h), min(ks[1], w))
+                    print("reducing Kernel")
+
+                if stride[0] > h or stride[1] > w:
+                    stride = (min(stride[0], h), min(stride[1], w))
+                    print("reducing stride")
+
+                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
+                z = unfold(x)  # (bn, nc * prod(**ks), L)
+                # Reshape to img shape
+                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+
+                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
+                               for i in range(z.shape[-1])]
+
+                o = torch.stack(output_list, axis=-1)
+                o = o * weighting
+
+                # Reverse reshape to img shape
+                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
+                # stitch crops together
+                decoded = fold(o)
+                decoded = decoded / normalization
+                return decoded
+
+            else:
+                return self.first_stage_model.encode(x)
+        else:
+            return self.first_stage_model.encode(x)
+
+    def shared_step(self, batch, **kwargs):
+        x, c = self.get_input(batch, self.first_stage_key)
+        loss = self(x, c)
+        return loss
+
+    def forward(self, x, c, *args, **kwargs):
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        if self.model.conditioning_key is not None:
+            assert c is not None
+            if self.cond_stage_trainable:
+                c = self.get_learned_conditioning(c)
+            if self.shorten_cond_schedule:  # TODO: drop this option
+                tc = self.cond_ids[t].to(self.device)
+                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+        return self.p_losses(x, c, t, *args, **kwargs)
+
+    def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
+        def rescale_bbox(bbox):
+            x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+            y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+            w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+            h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+            return x0, y0, w, h
+
+        return [rescale_bbox(b) for b in bboxes]
+
+    def apply_model(self, x_noisy, t, cond, return_ids=False):
+
+        if isinstance(cond, dict):
+            # hybrid case, cond is exptected to be a dict
+            pass
+        else:
+            if not isinstance(cond, list):
+                cond = [cond]
+            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+            cond = {key: cond}
+
+        if hasattr(self, "split_input_params"):
+            assert len(cond) == 1  # todo can only deal with one conditioning atm
+            assert not return_ids  
+            ks = self.split_input_params["ks"]  # eg. (128, 128)
+            stride = self.split_input_params["stride"]  # eg. (64, 64)
+
+            h, w = x_noisy.shape[-2:]
+
+            fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
+
+            z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)
+            # Reshape to img shape
+            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+            z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
+
+            if self.cond_stage_key in ["image", "LR_image", "segmentation",
+                                       'bbox_img'] and self.model.conditioning_key:  # todo check for completeness
+                c_key = next(iter(cond.keys()))  # get key
+                c = next(iter(cond.values()))  # get value
+                assert (len(c) == 1)  # todo extend to list with more than one elem
+                c = c[0]  # get element
+
+                c = unfold(c)
+                c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+
+                cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
+
+            elif self.cond_stage_key == 'coordinates_bbox':
+                assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
+
+                # assuming padding of unfold is always 0 and its dilation is always 1
+                n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
+                full_img_h, full_img_w = self.split_input_params['original_image_size']
+                # as we are operating on latents, we need the factor from the original image size to the
+                # spatial latent size to properly rescale the crops for regenerating the bbox annotations
+                num_downs = self.first_stage_model.encoder.num_resolutions - 1
+                rescale_latent = 2 ** (num_downs)
+
+                # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
+                # need to rescale the tl patch coordinates to be in between (0,1)
+                tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
+                                         rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
+                                        for patch_nr in range(z.shape[-1])]
+
+                # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
+                patch_limits = [(x_tl, y_tl,
+                                 rescale_latent * ks[0] / full_img_w,
+                                 rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
+                # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
+
+                # tokenize crop coordinates for the bounding boxes of the respective patches
+                patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
+                                      for bbox in patch_limits]  # list of length l with tensors of shape (1, 2)
+                print(patch_limits_tknzd[0].shape)
+                # cut tknzd crop position from conditioning
+                assert isinstance(cond, dict), 'cond must be dict to be fed into model'
+                cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
+                print(cut_cond.shape)
+
+                adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
+                adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
+                print(adapted_cond.shape)
+                adapted_cond = self.get_learned_conditioning(adapted_cond)
+                print(adapted_cond.shape)
+                adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
+                print(adapted_cond.shape)
+
+                cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
+
+            else:
+                cond_list = [cond for i in range(z.shape[-1])]  # Todo make this more efficient
+
+            # apply model by loop over crops
+            output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
+            assert not isinstance(output_list[0],
+                                  tuple)  # todo cant deal with multiple model outputs check this never happens
+
+            o = torch.stack(output_list, axis=-1)
+            o = o * weighting
+            # Reverse reshape to img shape
+            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
+            # stitch crops together
+            x_recon = fold(o) / normalization
+
+        else:
+            x_recon = self.model(x_noisy, t, **cond)
+
+        if isinstance(x_recon, tuple) and not return_ids:
+            return x_recon[0]
+        else:
+            return x_recon
+
+    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+    def _prior_bpd(self, x_start):
+        """
+        Get the prior KL term for the variational lower-bound, measured in
+        bits-per-dim.
+        This term can't be optimized, as it only depends on the encoder.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :return: a batch of [N] KL values (in bits), one per batch element.
+        """
+        batch_size = x_start.shape[0]
+        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+        return mean_flat(kl_prior) / np.log(2.0)
+
+    def p_losses(self, x_start, cond, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_output = self.apply_model(x_noisy, t, cond)
+
+        loss_dict = {}
+        prefix = 'train' if self.training else 'val'
+
+        if self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "eps":
+            target = noise
+        else:
+            raise NotImplementedError()
+
+        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+        logvar_t = self.logvar[t].to(self.device)
+        loss = loss_simple / torch.exp(logvar_t) + logvar_t
+        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+        if self.learn_logvar:
+            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+            loss_dict.update({'logvar': self.logvar.data.mean()})
+
+        loss = self.l_simple_weight * loss.mean()
+
+        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+        loss += (self.original_elbo_weight * loss_vlb)
+        loss_dict.update({f'{prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+                        return_x0=False, score_corrector=None, corrector_kwargs=None):
+        t_in = t
+        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+        if score_corrector is not None:
+            assert self.parameterization == "eps"
+            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+        if return_codebook_ids:
+            model_out, logits = model_out
+
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        else:
+            raise NotImplementedError()
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+        if quantize_denoised:
+            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        if return_codebook_ids:
+            return model_mean, posterior_variance, posterior_log_variance, logits
+        elif return_x0:
+            return model_mean, posterior_variance, posterior_log_variance, x_recon
+        else:
+            return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+        b, *_, device = *x.shape, x.device
+        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+                                       return_codebook_ids=return_codebook_ids,
+                                       quantize_denoised=quantize_denoised,
+                                       return_x0=return_x0,
+                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+        if return_codebook_ids:
+            raise DeprecationWarning("Support dropped.")
+            model_mean, _, model_log_variance, logits = outputs
+        elif return_x0:
+            model_mean, _, model_log_variance, x0 = outputs
+        else:
+            model_mean, _, model_log_variance = outputs
+
+        noise = noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+        if return_codebook_ids:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+        if return_x0:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+        else:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+                              log_every_t=None):
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        timesteps = self.num_timesteps
+        if batch_size is not None:
+            b = batch_size if batch_size is not None else shape[0]
+            shape = [batch_size] + list(shape)
+        else:
+            b = batch_size = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=self.device)
+        else:
+            img = x_T
+        intermediates = []
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+            else:
+                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+                        total=timesteps) if verbose else reversed(
+            range(0, timesteps))
+        if type(temperature) == float:
+            temperature = [temperature] * timesteps
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != 'hybrid'
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img, x0_partial = self.p_sample(img, cond, ts,
+                                            clip_denoised=self.clip_denoised,
+                                            quantize_denoised=quantize_denoised, return_x0=True,
+                                            temperature=temperature[i], noise_dropout=noise_dropout,
+                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1. - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(x0_partial)
+            if callback: callback(i)
+            if img_callback: img_callback(img, i)
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_loop(self, cond, shape, return_intermediates=False,
+                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, start_T=None,
+                      log_every_t=None):
+
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        device = self.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        intermediates = [img]
+        if timesteps is None:
+            timesteps = self.num_timesteps
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+            range(0, timesteps))
+
+        if mask is not None:
+            assert x0 is not None
+            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != 'hybrid'
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img = self.p_sample(img, cond, ts,
+                                clip_denoised=self.clip_denoised,
+                                quantize_denoised=quantize_denoised)
+            if mask is not None:
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1. - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(img)
+            if callback: callback(i)
+            if img_callback: img_callback(img, i)
+
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+               verbose=True, timesteps=None, quantize_denoised=False,
+               mask=None, x0=None, shape=None,**kwargs):
+        if shape is None:
+            shape = (batch_size, self.channels, self.image_size, self.image_size)
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+            else:
+                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+        return self.p_sample_loop(cond,
+                                  shape,
+                                  return_intermediates=return_intermediates, x_T=x_T,
+                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+                                  mask=mask, x0=x0)
+
+    @torch.no_grad()
+    def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
+
+        if ddim:
+            ddim_sampler = DDIMSampler(self)
+            shape = (self.channels, self.image_size, self.image_size)
+            samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
+                                                        shape,cond,verbose=False,**kwargs)
+
+        else:
+            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+                                                 return_intermediates=True,**kwargs)
+
+        return samples, intermediates
+
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                   plot_diffusion_rows=True, **kwargs):
+
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+                                           return_first_stage_outputs=True,
+                                           force_c_encode=True,
+                                           return_original_cond=True,
+                                           bs=N)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
+                log["conditioning"] = xc
+            elif self.cond_stage_key == 'class_label':
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+                log['conditioning'] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with self.ema_scope("Plotting"):
+                samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+                                                         ddim_steps=ddim_steps,eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+                    self.first_stage_model, IdentityFirstStage):
+                # also display when quantizing x0 while sampling
+                with self.ema_scope("Plotting Quantized Denoised"):
+                    samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+                                                             ddim_steps=ddim_steps,eta=ddim_eta,
+                                                             quantize_denoised=True)
+                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+                    #                                      quantize_denoised=True)
+                x_samples = self.decode_first_stage(samples.to(self.device))
+                log["samples_x0_quantized"] = x_samples
+
+            if inpaint:
+                # make a simple center square
+                b, h, w = z.shape[0], z.shape[2], z.shape[3]
+                mask = torch.ones(N, h, w).to(self.device)
+                # zeros will be filled in
+                mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+                mask = mask[:, None, ...]
+                with self.ema_scope("Plotting Inpaint"):
+
+                    samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
+                                                ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+                x_samples = self.decode_first_stage(samples.to(self.device))
+                log["samples_inpainting"] = x_samples
+                log["mask"] = mask
+
+                # outpaint
+                with self.ema_scope("Plotting Outpaint"):
+                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
+                                                ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+                x_samples = self.decode_first_stage(samples.to(self.device))
+                log["samples_outpainting"] = x_samples
+
+        if plot_progressive_rows:
+            with self.ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(c,
+                                                               shape=(self.channels, self.image_size, self.image_size),
+                                                               batch_size=N)
+            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+            log["progressive_row"] = prog_row
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.cond_stage_trainable:
+            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+            params = params + list(self.cond_stage_model.parameters())
+        if self.learn_logvar:
+            print('Diffusion model optimizing logvar')
+            params.append(self.logvar)
+        opt = torch.optim.AdamW(params, lr=lr)
+        if self.use_scheduler:
+            assert 'target' in self.scheduler_config
+            scheduler = instantiate_from_config(self.scheduler_config)
+
+            print("Setting up LambdaLR scheduler...")
+            scheduler = [
+                {
+                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+                    'interval': 'step',
+                    'frequency': 1
+                }]
+            return [opt], scheduler
+        return opt
+
+    @torch.no_grad()
+    def to_rgb(self, x):
+        x = x.float()
+        if not hasattr(self, "colorize"):
+            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+        x = nn.functional.conv2d(x, weight=self.colorize)
+        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+        return x
+
+
+class DiffusionWrapperV1(pl.LightningModule):
+    def __init__(self, diff_model_config, conditioning_key):
+        super().__init__()
+        self.diffusion_model = instantiate_from_config(diff_model_config)
+        self.conditioning_key = conditioning_key
+        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
+
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+        if self.conditioning_key is None:
+            out = self.diffusion_model(x, t)
+        elif self.conditioning_key == 'concat':
+            xc = torch.cat([x] + c_concat, dim=1)
+            out = self.diffusion_model(xc, t)
+        elif self.conditioning_key == 'crossattn':
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(x, t, context=cc)
+        elif self.conditioning_key == 'hybrid':
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc)
+        elif self.conditioning_key == 'adm':
+            cc = c_crossattn[0]
+            out = self.diffusion_model(x, t, y=cc)
+        else:
+            raise NotImplementedError()
+
+        return out
+
+
+class Layout2ImgDiffusionV1(LatentDiffusionV1):
+    # TODO: move all layout-specific hacks to this class
+    def __init__(self, cond_stage_key, *args, **kwargs):
+        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
+        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+
+    def log_images(self, batch, N=8, *args, **kwargs):
+        logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+
+        key = 'train' if self.training else 'validation'
+        dset = self.trainer.datamodule.datasets[key]
+        mapper = dset.conditional_builders[self.cond_stage_key]
+
+        bbox_imgs = []
+        map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
+        for tknzd_bbox in batch[self.cond_stage_key][:N]:
+            bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
+            bbox_imgs.append(bboximg)
+
+        cond_img = torch.stack(bbox_imgs, dim=0)
+        logs['bbox_image'] = cond_img
+        return logs
+
+setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
+setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
+setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
+setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
diff --git a/extensions-builtin/ScuNET/preload.py b/extensions-builtin/ScuNET/preload.py
new file mode 100644
index 00000000..f12c5b90
--- /dev/null
+++ b/extensions-builtin/ScuNET/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+    parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
diff --git a/modules/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
similarity index 95%
rename from modules/scunet_model.py
rename to extensions-builtin/ScuNET/scripts/scunet_model.py
index 36a996bf..e0fbf3a3 100644
--- a/modules/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url
 
 import modules.upscaler
 from modules import devices, modelloader
-from modules.scunet_model_arch import SCUNet as net
+from scunet_model_arch import SCUNet as net
 
 
 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -49,14 +49,13 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         if model is None:
             return img
 
-        device = devices.device_scunet
+        device = devices.get_device_for('scunet')
         img = np.array(img)
         img = img[:, :, ::-1]
         img = np.moveaxis(img, 2, 0) / 255
         img = torch.from_numpy(img).float()
         img = img.unsqueeze(0).to(device)
 
-        img = img.to(device)
         with torch.no_grad():
             output = model(img)
         output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
@@ -67,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         return PIL.Image.fromarray(output, 'RGB')
 
     def load_model(self, path: str):
-        device = devices.device_scunet
+        device = devices.get_device_for('scunet')
         if "http" in path:
             filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
                                           progress=True)
diff --git a/modules/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py
similarity index 100%
rename from modules/scunet_model_arch.py
rename to extensions-builtin/ScuNET/scunet_model_arch.py
diff --git a/extensions-builtin/SwinIR/preload.py b/extensions-builtin/SwinIR/preload.py
new file mode 100644
index 00000000..567e44bc
--- /dev/null
+++ b/extensions-builtin/SwinIR/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+    parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
diff --git a/modules/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
similarity index 80%
rename from modules/swinir_model.py
rename to extensions-builtin/SwinIR/scripts/swinir_model.py
index baa02e3d..9a74b253 100644
--- a/modules/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -7,15 +7,14 @@ from PIL import Image
 from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
 
-from modules import modelloader
-from modules.shared import cmd_opts, opts, device
-from modules.swinir_model_arch import SwinIR as net
-from modules.swinir_model_arch_v2 import Swin2SR as net2
+from modules import modelloader, devices, script_callbacks, shared
+from modules.shared import cmd_opts, opts
+from swinir_model_arch import SwinIR as net
+from swinir_model_arch_v2 import Swin2SR as net2
 from modules.upscaler import Upscaler, UpscalerData
 
-precision_scope = (
-    torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
-)
+
+device_swinir = devices.get_device_for('swinir')
 
 
 class UpscalerSwinIR(Upscaler):
@@ -42,7 +41,7 @@ class UpscalerSwinIR(Upscaler):
         model = self.load_model(model_file)
         if model is None:
             return img
-        model = model.to(device)
+        model = model.to(device_swinir, dtype=devices.dtype)
         img = upscale(img, model)
         try:
             torch.cuda.empty_cache()
@@ -94,25 +93,27 @@ class UpscalerSwinIR(Upscaler):
             model.load_state_dict(pretrained_model[params], strict=True)
         else:
             model.load_state_dict(pretrained_model, strict=True)
-        if not cmd_opts.no_half:
-            model = model.half()
         return model
 
 
 def upscale(
         img,
         model,
-        tile=opts.SWIN_tile,
-        tile_overlap=opts.SWIN_tile_overlap,
+        tile=None,
+        tile_overlap=None,
         window_size=8,
         scale=4,
 ):
+    tile = tile or opts.SWIN_tile
+    tile_overlap = tile_overlap or opts.SWIN_tile_overlap
+
+
     img = np.array(img)
     img = img[:, :, ::-1]
     img = np.moveaxis(img, 2, 0) / 255
     img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(device)
-    with torch.no_grad(), precision_scope("cuda"):
+    img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
+    with torch.no_grad(), devices.autocast():
         _, _, h_old, w_old = img.size()
         h_pad = (h_old // window_size + 1) * window_size - h_old
         w_pad = (w_old // window_size + 1) * window_size - w_old
@@ -139,8 +140,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
     stride = tile - tile_overlap
     h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
     w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-    E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
-    W = torch.zeros_like(E, dtype=torch.half, device=device)
+    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
+    W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
 
     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
         for h_idx in h_idx_list:
@@ -159,3 +160,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
     output = E.div_(W)
 
     return output
+
+
+def on_ui_settings():
+    import gradio as gr
+
+    shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
+    shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
diff --git a/modules/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py
similarity index 100%
rename from modules/swinir_model_arch.py
rename to extensions-builtin/SwinIR/swinir_model_arch.py
diff --git a/modules/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
similarity index 100%
rename from modules/swinir_model_arch_v2.py
rename to extensions-builtin/SwinIR/swinir_model_arch_v2.py
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
new file mode 100644
index 00000000..eccfb0f9
--- /dev/null
+++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
@@ -0,0 +1,107 @@
+// Stable Diffusion WebUI - Bracket checker
+// Version 1.0
+// By Hingashi no Florin/Bwin4L
+// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
+// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
+
+function checkBrackets(evt) {
+  textArea = evt.target;
+  tabName = evt.target.parentElement.parentElement.id.split("_")[0];
+  counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
+
+  promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
+
+  errorStringParen  = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
+  errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
+  errorStringCurly  = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
+
+  openBracketRegExp = /\(/g;
+  closeBracketRegExp = /\)/g;
+
+  openSquareBracketRegExp = /\[/g;
+  closeSquareBracketRegExp = /\]/g;
+
+  openCurlyBracketRegExp = /\{/g;
+  closeCurlyBracketRegExp = /\}/g;
+
+  totalOpenBracketMatches = 0;
+  totalCloseBracketMatches = 0;
+  totalOpenSquareBracketMatches = 0;
+  totalCloseSquareBracketMatches = 0;
+  totalOpenCurlyBracketMatches = 0;
+  totalCloseCurlyBracketMatches = 0;
+
+  openBracketMatches = textArea.value.match(openBracketRegExp);
+  if(openBracketMatches) {
+    totalOpenBracketMatches = openBracketMatches.length;
+  }
+
+  closeBracketMatches = textArea.value.match(closeBracketRegExp);
+  if(closeBracketMatches) {
+    totalCloseBracketMatches = closeBracketMatches.length;
+  }
+
+  openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
+  if(openSquareBracketMatches) {
+    totalOpenSquareBracketMatches = openSquareBracketMatches.length;
+  }
+
+  closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
+  if(closeSquareBracketMatches) {
+    totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
+  }
+
+  openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
+  if(openCurlyBracketMatches) {
+    totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
+  }
+
+  closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
+  if(closeCurlyBracketMatches) {
+    totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
+  }
+
+  if(totalOpenBracketMatches != totalCloseBracketMatches) {
+    if(!counterElt.title.includes(errorStringParen)) {
+      counterElt.title += errorStringParen;
+    }
+  } else {
+    counterElt.title = counterElt.title.replace(errorStringParen, '');
+  }
+
+  if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
+    if(!counterElt.title.includes(errorStringSquare)) {
+      counterElt.title += errorStringSquare;
+    }
+  } else {
+    counterElt.title = counterElt.title.replace(errorStringSquare, '');
+  }
+
+  if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
+    if(!counterElt.title.includes(errorStringCurly)) {
+      counterElt.title += errorStringCurly;
+    }
+  } else {
+    counterElt.title = counterElt.title.replace(errorStringCurly, '');
+  }
+
+  if(counterElt.title != '') {
+    counterElt.style = 'color: #FF5555;';
+  } else {
+    counterElt.style = '';
+  }
+}
+
+var shadowRootLoaded = setInterval(function() {
+  var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
+  if(shadowTextArea.length < 1) {
+      return false;
+  }
+
+  clearInterval(shadowRootLoaded);
+
+  document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
+  document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
+  document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
+  document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
+}, 1000);
diff --git a/extensions-builtin/roll-artist/scripts/roll-artist.py b/extensions-builtin/roll-artist/scripts/roll-artist.py
new file mode 100644
index 00000000..c3bc1fd0
--- /dev/null
+++ b/extensions-builtin/roll-artist/scripts/roll-artist.py
@@ -0,0 +1,50 @@
+import random
+
+from modules import script_callbacks, shared
+import gradio as gr
+
+art_symbol = '\U0001f3a8'  # 🎨
+global_prompt = None
+related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
+
+
+def roll_artist(prompt):
+    allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
+    artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
+
+    return prompt + ", " + artist.name if prompt != '' else artist.name
+
+
+def add_roll_button(prompt):
+    roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
+
+    roll.click(
+        fn=roll_artist,
+        _js="update_txt2img_tokens",
+        inputs=[
+            prompt,
+        ],
+        outputs=[
+            prompt,
+        ]
+    )
+
+
+def after_component(component, **kwargs):
+    global global_prompt
+
+    elem_id = kwargs.get('elem_id', None)
+    if elem_id not in related_ids:
+        return
+
+    if elem_id == "txt2img_prompt":
+        global_prompt = component
+    elif elem_id == "txt2img_clear_prompt":
+        add_roll_button(global_prompt)
+    elif elem_id == "img2img_prompt":
+        global_prompt = component
+    elif elem_id == "img2img_clear_prompt":
+        add_roll_button(global_prompt)
+
+
+script_callbacks.on_after_component(after_component)
diff --git a/extensions/put extensions here.txt b/extensions/put extensions here.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/html/footer.html b/html/footer.html
new file mode 100644
index 00000000..a8f2adf7
--- /dev/null
+++ b/html/footer.html
@@ -0,0 +1,9 @@
+<div>
+        <a href="/docs">API</a>
+         • 
+        <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
+         • 
+        <a href="https://gradio.app">Gradio</a>
+         • 
+        <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
+</div>
diff --git a/html/licenses.html b/html/licenses.html
new file mode 100644
index 00000000..9eeaa072
--- /dev/null
+++ b/html/licenses.html
@@ -0,0 +1,392 @@
+<style>
+    #licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
+    #licenses small {font-size: 0.95em; opacity: 0.85;}
+    #licenses pre { margin: 1em 0 2em 0;}
+</style>
+
+<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
+<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
+<pre>
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in
+   the documentation and/or other materials provided with the
+   distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
+</pre>
+
+
+<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
+<small>Code for architecture and reading models copied.</small>
+<pre>
+MIT License
+
+Copyright (c) 2021 victorca25
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+</pre>
+
+<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
+<small>Some code is copied to support ESRGAN models.</small>
+<pre>
+BSD 3-Clause License
+
+Copyright (c) 2021, Xintao Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+</pre>
+
+<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
+<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
+<pre>
+MIT License
+
+Copyright (c) 2022 InvokeAI Team
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+</pre>
+
+<h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
+<small>Code added by contirubtors, most likely copied from this repository.</small>
+<pre>
+MIT License
+
+Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+</pre>
+
+<h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
+<small>Some small amounts of code borrowed and reworked.</small>
+<pre>
+MIT License
+
+Copyright (c) 2022 pharmapsychotic
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+</pre>
+
+<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
+<small>Code added by contirubtors, most likely copied from this repository.</small>
+
+<pre>
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [2021] [SwinIR Authors]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+</pre>
+
diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js
index 96f1c00d..66f26a22 100644
--- a/javascript/aspectRatioOverlay.js
+++ b/javascript/aspectRatioOverlay.js
@@ -3,12 +3,12 @@ let currentWidth = null;
 let currentHeight = null;
 let arFrameTimeout = setTimeout(function(){},0);
 
-function dimensionChange(e,dimname){
+function dimensionChange(e, is_width, is_height){
 
-	if(dimname == 'Width'){
+	if(is_width){
 		currentWidth = e.target.value*1.0
 	}
-	if(dimname == 'Height'){
+	if(is_height){
 		currentHeight = e.target.value*1.0
 	}
 
@@ -18,22 +18,13 @@ function dimensionChange(e,dimname){
 		return;
 	}
 
-	var img2imgMode = gradioApp().querySelector('#mode_img2img.tabs > div > button.rounded-t-lg.border-gray-200')
-	if(img2imgMode){
-		img2imgMode=img2imgMode.innerText
-	}else{
-		return;
-	}
-
-	var redrawImage = gradioApp().querySelector('div[data-testid=image] img');
-	var inpaintImage = gradioApp().querySelector('#img2maskimg div[data-testid=image] img')
-
 	var targetElement = null;
 
-	if(img2imgMode=='img2img' && redrawImage){
-		targetElement = redrawImage;
-	}else if(img2imgMode=='Inpaint' && inpaintImage){
-		targetElement = inpaintImage;
+    var tabIndex = get_tab_index('mode_img2img')
+	if(tabIndex == 0){
+		targetElement = gradioApp().querySelector('div[data-testid=image] img');
+	} else if(tabIndex == 1){
+		targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
 	}
 
 	if(targetElement){
@@ -98,22 +89,20 @@ onUiUpdate(function(){
 	var inImg2img   = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
 	if(inImg2img){
 		let inputs = gradioApp().querySelectorAll('input');
-		inputs.forEach(function(e){ 
-			let parentLabel = e.parentElement.querySelector('label')
-			if(parentLabel && parentLabel.innerText){
-				if(!e.classList.contains('scrollwatch')){
-					if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){
-						e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} )
-						e.classList.add('scrollwatch')
-					}
-					if(parentLabel.innerText == 'Width'){
-						currentWidth = e.value*1.0
-					}
-					if(parentLabel.innerText == 'Height'){
-						currentHeight = e.value*1.0
-					}
-				}
-			} 
+		inputs.forEach(function(e){
+		    var is_width = e.parentElement.id == "img2img_width"
+		    var is_height = e.parentElement.id == "img2img_height"
+
+			if((is_width || is_height) && !e.classList.contains('scrollwatch')){
+				e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
+				e.classList.add('scrollwatch')
+			}
+			if(is_width){
+				currentWidth = e.value*1.0
+			}
+			if(is_height){
+				currentHeight = e.value*1.0
+			}
 		})
 	}
 });
diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js
index fe67c42e..11bcce1b 100644
--- a/javascript/contextMenus.js
+++ b/javascript/contextMenus.js
@@ -9,7 +9,7 @@ contextMenuInit = function(){
 
   function showContextMenu(event,element,menuEntries){
     let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
-    let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; 
+    let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
 
     let oldMenu = gradioApp().querySelector('#context-menu')
     if(oldMenu){
@@ -61,15 +61,15 @@ contextMenuInit = function(){
 
   }
 
-  function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){
-    
-    currentItems = menuSpecs.get(targetEmementSelector)
-    
+  function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
+
+    currentItems = menuSpecs.get(targetElementSelector)
+
     if(!currentItems){
       currentItems = []
-      menuSpecs.set(targetEmementSelector,currentItems);
+      menuSpecs.set(targetElementSelector,currentItems);
     }
-    let newItem = {'id':targetEmementSelector+'_'+uid(), 
+    let newItem = {'id':targetElementSelector+'_'+uid(),
                    'name':entryName,
                    'func':entryFunction,
                    'isNew':true}
@@ -97,7 +97,7 @@ contextMenuInit = function(){
       if(source.id && source.id.indexOf('check_progress')>-1){
         return
       }
-      
+
       let oldMenu = gradioApp().querySelector('#context-menu')
       if(oldMenu){
         oldMenu.remove()
@@ -117,7 +117,7 @@ contextMenuInit = function(){
       })
     });
     eventListenerApplied=true
-  
+
   }
 
   return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
@@ -152,8 +152,8 @@ addContextMenuEventListener = initResponse[2];
     generateOnRepeat('#img2img_generate','#img2img_interrupt');
   })
 
-  let cancelGenerateForever = function(){ 
-    clearInterval(window.generateOnRepeatInterval) 
+  let cancelGenerateForever = function(){
+    clearInterval(window.generateOnRepeatInterval)
   }
 
   appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
@@ -162,7 +162,7 @@ addContextMenuEventListener = initResponse[2];
   appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
 
   appendContextMenuOption('#roll','Roll three',
-    function(){ 
+    function(){
       let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
       setTimeout(function(){rollbutton.click()},100)
       setTimeout(function(){rollbutton.click()},200)
diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js
index 070cf255..fe008924 100644
--- a/javascript/dragdrop.js
+++ b/javascript/dragdrop.js
@@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) {
         return;
     }
 
+    const tmpFile = files[0];
+
     imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
     const callback = () => {
         const fileInput = imgWrap.querySelector('input[type="file"]');
         if ( fileInput ) {
-            fileInput.files = files;
+            if ( files.length === 0 ) {
+                files = new DataTransfer();
+                files.items.add(tmpFile);
+                fileInput.files = files.files;
+            } else {
+                fileInput.files = files;
+            }
             fileInput.dispatchEvent(new Event('change'));   
         }
     };
@@ -43,7 +51,7 @@ function dropReplaceImage( imgWrap, files ) {
 window.document.addEventListener('dragover', e => {
     const target = e.composedPath()[0];
     const imgWrap = target.closest('[data-testid="image"]');
-    if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) {
+    if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
         return;
     }
     e.stopPropagation();
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index c0d29a74..b947cbec 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -1,7 +1,6 @@
 addEventListener('keydown', (event) => {
 	let target = event.originalTarget || event.composedPath()[0];
-	if (!target.hasAttribute("placeholder")) return;
-	if (!target.placeholder.toLowerCase().includes("prompt")) return;
+	if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
 	if (! (event.metaKey || event.ctrlKey)) return;
 
 
diff --git a/javascript/extensions.js b/javascript/extensions.js
new file mode 100644
index 00000000..59179ca6
--- /dev/null
+++ b/javascript/extensions.js
@@ -0,0 +1,35 @@
+
+function extensions_apply(_, _){
+    disable = []
+    update = []
+    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
+        if(x.name.startsWith("enable_") && ! x.checked)
+            disable.push(x.name.substr(7))
+
+        if(x.name.startsWith("update_") && x.checked)
+            update.push(x.name.substr(7))
+    })
+
+    restart_reload()
+
+    return [JSON.stringify(disable), JSON.stringify(update)]
+}
+
+function extensions_check(){
+    gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
+        x.innerHTML = "Loading..."
+    })
+
+    return []
+}
+
+function install_extension_from_index(button, url){
+    button.disabled = "disabled"
+    button.value = "Installing..."
+
+    textarea = gradioApp().querySelector('#extension_to_install textarea')
+    textarea.value = url
+	textarea.dispatchEvent(new Event("input", { bubbles: true }))
+
+    gradioApp().querySelector('#install_extension_button').click()
+}
diff --git a/javascript/generationParams.js b/javascript/generationParams.js
new file mode 100644
index 00000000..95f05093
--- /dev/null
+++ b/javascript/generationParams.js
@@ -0,0 +1,33 @@
+// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
+
+let txt2img_gallery, img2img_gallery, modal = undefined;
+onUiUpdate(function(){
+	if (!txt2img_gallery) {
+		txt2img_gallery = attachGalleryListeners("txt2img")
+	}
+	if (!img2img_gallery) {
+		img2img_gallery = attachGalleryListeners("img2img")
+	}
+	if (!modal) {
+		modal = gradioApp().getElementById('lightboxModal')
+		modalObserver.observe(modal,  { attributes : true, attributeFilter : ['style'] });
+	}
+});
+
+let modalObserver = new MutationObserver(function(mutations) {
+	mutations.forEach(function(mutationRecord) {
+		let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
+		if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
+			gradioApp().getElementById(selectedTab+"_generation_info_button").click()
+	});
+});
+
+function attachGalleryListeners(tab_name) {
+	gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
+	gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
+	gallery?.addEventListener('keydown', (e) => {
+		if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
+			gradioApp().getElementById(tab_name+"_generation_info_button").click()
+	});
+	return gallery;
+}
diff --git a/javascript/hints.js b/javascript/hints.js
index b98012f5..63e17e05 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -6,6 +6,7 @@ titles = {
 	"GFPGAN": "Restore low quality faces using GFPGAN neural network",
 	"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
 	"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
+	"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution", 
 
 	"Batch count": "How many batches of images to create",
 	"Batch size": "How many image to create in a single batch",
@@ -17,6 +18,7 @@ titles = {
     "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
     "\u{1f4c2}": "Open images output directory",
     "\u{1f4be}": "Save style",
+    "\U0001F5D1": "Clear prompt",
     "\u{1f4cb}": "Apply selected styles to current prompt",
 
     "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
@@ -62,8 +64,8 @@ titles = {
 
     "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
 
-    "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
-    "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
+    "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
+    "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
     "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
 
     "Loopback": "Process an image, use it as an input, repeat.",
@@ -75,6 +77,7 @@ titles = {
     "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
 
     "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
+    "Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
 
     "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
 
@@ -91,6 +94,13 @@ titles = {
 
     "Weighted sum": "Result = A * (1 - M) + B * M",
     "Add difference": "Result = A + (B - C) * M",
+
+    "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n   rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG:   0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
+
+    "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
+
+    "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
+    "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
 }
 
 
diff --git a/javascript/images_history.js b/javascript/images_history.js
deleted file mode 100644
index f7d052c3..00000000
--- a/javascript/images_history.js
+++ /dev/null
@@ -1,206 +0,0 @@
-var images_history_click_image = function(){
-    if (!this.classList.contains("transform")){        
-        var gallery = images_history_get_parent_by_class(this, "images_history_cantainor");
-        var buttons = gallery.querySelectorAll(".gallery-item");
-        var i = 0;
-        var hidden_list = [];
-        buttons.forEach(function(e){
-            if (e.style.display == "none"){
-                hidden_list.push(i);
-            }
-            i += 1;
-        })
-        if (hidden_list.length > 0){
-            setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
-        }        
-    }    
-    images_history_set_image_info(this); 
-}
-
-var images_history_click_tab = function(){
-    var tabs_box = gradioApp().getElementById("images_history_tab");
-    if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
-        gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
-        tabs_box.classList.add(this.getAttribute("tabname"))
-    }                
-}
-
-function images_history_disabled_del(){
-    gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
-        btn.setAttribute('disabled','disabled');
-    }); 
-}
-
-function images_history_get_parent_by_class(item, class_name){
-    var parent = item.parentElement;
-    while(!parent.classList.contains(class_name)){
-        parent = parent.parentElement;
-    }
-    return parent;  
-}
-
-function images_history_get_parent_by_tagname(item, tagname){
-    var parent = item.parentElement;
-    tagname = tagname.toUpperCase()
-    while(parent.tagName != tagname){
-        console.log(parent.tagName, tagname)
-        parent = parent.parentElement;
-    }  
-    return parent;
-}
-
-function images_history_hide_buttons(hidden_list, gallery){
-    var buttons = gallery.querySelectorAll(".gallery-item");
-    var num = 0;
-    buttons.forEach(function(e){
-        if (e.style.display == "none"){
-            num += 1;
-        }
-    });
-    if (num == hidden_list.length){
-        setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
-    } 
-    for( i in hidden_list){
-        buttons[hidden_list[i]].style.display = "none";
-    }    
-}
-
-function images_history_set_image_info(button){
-    var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item");
-    var index = -1;
-    var i = 0;
-    buttons.forEach(function(e){
-        if(e == button){
-            index = i;
-        }
-        if(e.style.display != "none"){
-            i += 1;
-        }        
-    });
-    var gallery = images_history_get_parent_by_class(button, "images_history_cantainor");
-    var set_btn = gallery.querySelector(".images_history_set_index");
-    var curr_idx = set_btn.getAttribute("img_index", index);  
-    if (curr_idx != index) {
-        set_btn.setAttribute("img_index", index);        
-        images_history_disabled_del();
-    }
-    set_btn.click();
-    
-}
-
-function images_history_get_current_img(tabname, image_path, files){
-    return [
-        gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"), 
-        image_path, 
-        files
-    ];
-}
-
-function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
-    image_index = parseInt(image_index);
-    var tab = gradioApp().getElementById(tabname + '_images_history');
-    var set_btn = tab.querySelector(".images_history_set_index");
-    var buttons = [];
-    tab.querySelectorAll(".gallery-item").forEach(function(e){
-        if (e.style.display != 'none'){
-            buttons.push(e);
-        }
-    });    
-    var img_num = buttons.length / 2;
-    if (img_num <= del_num){
-        setTimeout(function(tabname){
-            gradioApp().getElementById(tabname + '_images_history_renew_page').click();
-        }, 30, tabname); 
-    } else {
-        var next_img  
-        for (var i = 0; i < del_num; i++){
-            if (image_index + i < image_index + img_num){
-                buttons[image_index + i].style.display = 'none';
-                buttons[image_index + img_num + 1].style.display = 'none';
-                next_img = image_index + i + 1
-            }
-        }
-        var bnt;
-        if (next_img  >= img_num){
-            btn = buttons[image_index - del_num];
-        } else {            
-            btn = buttons[next_img];          
-        } 
-        setTimeout(function(btn){btn.click()}, 30, btn);
-    }
-    images_history_disabled_del();  
-    return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
-}
-
-function images_history_turnpage(img_path, page_index, image_index, tabname){
-    var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
-    buttons.forEach(function(elem) {
-        elem.style.display = 'block';
-    })
-    return [img_path, page_index, image_index, tabname];
-}
-
-function images_history_enable_del_buttons(){
-    gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
-        btn.removeAttribute('disabled');
-    })
-}
-
-function images_history_init(){ 
-    var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
-    if (load_txt2img_button){        
-        for (var i in images_history_tab_list ){
-            tab = images_history_tab_list[i];
-            gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
-            gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
-            gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
-            gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");            
-                     
-        }
-        var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
-        tabs_box.setAttribute("id", "images_history_tab");        
-        var tab_btns = tabs_box.querySelectorAll("button");        
-        for (var i in images_history_tab_list){               
-            var tabname = images_history_tab_list[i]
-            tab_btns[i].setAttribute("tabname", tabname);
-
-            // this refreshes history upon tab switch
-            // until the history is known to work well, which is not the case now, we do not do this at startup
-            //tab_btns[i].addEventListener('click', images_history_click_tab);
-        }    
-        tabs_box.classList.add(images_history_tab_list[0]);
-
-        // same as above, at page load
-        //load_txt2img_button.click();
-    } else {
-        setTimeout(images_history_init, 500);
-    } 
-}
-
-var images_history_tab_list = ["txt2img", "img2img", "extras"];
-setTimeout(images_history_init, 500);
-document.addEventListener("DOMContentLoaded", function() {
-    var mutationObserver = new MutationObserver(function(m){
-        for (var i in images_history_tab_list ){
-            let tabname = images_history_tab_list[i]
-            var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
-            buttons.forEach(function(bnt){    
-                bnt.addEventListener('click', images_history_click_image, true);
-            });
-
-            // same as load_txt2img_button.click() above
-            /*
-            var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
-            if (cls_btn){
-                cls_btn.addEventListener('click', function(){
-                    gradioApp().getElementById(tabname + '_images_history_renew_page').click();
-                }, false);
-            }*/
-
-        }     
-    });
-    mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
-
-});
-
-
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index 9e380c65..67916536 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -13,6 +13,15 @@ function showModal(event) {
     }
     lb.style.display = "block";
     lb.focus()
+
+    const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
+    const tabImg2Img = gradioApp().getElementById("tab_img2img")
+    // show the save button in modal only on txt2img or img2img tabs
+    if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
+        gradioApp().getElementById("modal_save").style.display = "inline"
+    } else {
+        gradioApp().getElementById("modal_save").style.display = "none"
+    }
     event.stopPropagation()
 }
 
@@ -81,6 +90,25 @@ function modalImageSwitch(offset) {
     }
 }
 
+function saveImage(){
+    const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
+    const tabImg2Img = gradioApp().getElementById("tab_img2img")
+    const saveTxt2Img = "save_txt2img"
+    const saveImg2Img = "save_img2img"
+    if (tabTxt2Img.style.display != "none") {
+        gradioApp().getElementById(saveTxt2Img).click()
+    } else if (tabImg2Img.style.display != "none") {
+        gradioApp().getElementById(saveImg2Img).click()
+    } else {
+        console.error("missing implementation for saving modal of this type")
+    }
+}
+
+function modalSaveImage(event) {
+    saveImage()
+    event.stopPropagation()
+}
+
 function modalNextImage(event) {
     modalImageSwitch(1)
     event.stopPropagation()
@@ -93,6 +121,9 @@ function modalPrevImage(event) {
 
 function modalKeyHandler(event) {
     switch (event.key) {
+        case "s":
+            saveImage()
+            break;
         case "ArrowLeft":
             modalPrevImage(event)
             break;
@@ -198,6 +229,14 @@ document.addEventListener("DOMContentLoaded", function() {
     modalTileImage.title = "Preview tiling";
     modalControls.appendChild(modalTileImage)
 
+    const modalSave = document.createElement("span")
+    modalSave.className = "modalSave cursor"
+    modalSave.id = "modal_save"
+    modalSave.innerHTML = "&#x1F5AB;"
+    modalSave.addEventListener("click", modalSaveImage, true)
+    modalSave.title = "Save Image(s)"
+    modalControls.appendChild(modalSave)
+
     const modalClose = document.createElement('span')
     modalClose.className = 'modalClose cursor';
     modalClose.innerHTML = '&times;'
diff --git a/javascript/localization.js b/javascript/localization.js
index e6644635..f92d2d24 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -108,6 +108,9 @@ function processNode(node){
 
 function dumpTranslations(){
     dumped = {}
+    if (localization.rtl) {
+        dumped.rtl = true
+    }
 
     Object.keys(original_lines).forEach(function(text){
         if(dumped[text] !== undefined)  return
@@ -129,6 +132,24 @@ onUiUpdate(function(m){
 
 document.addEventListener("DOMContentLoaded", function() {
     processNode(gradioApp())
+
+    if (localization.rtl) {  // if the language is from right to left,
+        (new MutationObserver((mutations, observer) => { // wait for the style to load
+            mutations.forEach(mutation => {
+                mutation.addedNodes.forEach(node => {
+                    if (node.tagName === 'STYLE') {
+                        observer.disconnect();
+
+                        for (const x of node.sheet.rules) {  // find all rtl media rules
+                            if (Array.from(x.media || []).includes('rtl')) {
+                                x.media.appendMedium('all');  // enable them
+                            }
+                        }
+                    }
+                })
+            });
+        })).observe(gradioApp(), { childList: true });
+    }
 })
 
 function download_localization() {
diff --git a/javascript/notification.js b/javascript/notification.js
index f96de313..040a3afa 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -15,7 +15,7 @@ onUiUpdate(function(){
         }
     }
 
-    const galleryPreviews = gradioApp().querySelectorAll('img.h-full.w-full.overflow-hidden');
+    const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
 
     if (galleryPreviews == null) return;
 
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 24ab4795..d6323ed9 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -3,57 +3,75 @@ global_progressbars = {}
 galleries = {}
 galleryObservers = {}
 
+// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
+timeoutIds = {}
+
 function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
-    var progressbar = gradioApp().getElementById(id_progressbar)
+    // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
+    // every time you use gr.HTML(elem_id='xxx'), so we handle this here
+    var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
+    var progressbarParent
+    if(progressbar){
+        progressbarParent = gradioApp().querySelector("#"+id_progressbar)
+    } else{
+        progressbar = gradioApp().getElementById(id_progressbar)
+        progressbarParent = null
+    }
+
     var skip = id_skip ? gradioApp().getElementById(id_skip) : null
     var interrupt = gradioApp().getElementById(id_interrupt)
-    
+
     if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
         if(progressbar.innerText){
-            let newtitle = 'Stable Diffusion - ' + progressbar.innerText.slice(2)
+            let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
             if(document.title != newtitle){
-                document.title =  newtitle;          
+                document.title =  newtitle;
             }
         }else{
             let newtitle = 'Stable Diffusion'
             if(document.title != newtitle){
-                document.title =  newtitle;          
+                document.title =  newtitle;
             }
         }
     }
-    
+
 	if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
 	    global_progressbars[id_progressbar] = progressbar
 
         var mutationObserver = new MutationObserver(function(m){
+            if(timeoutIds[id_part]) return;
+
             preview = gradioApp().getElementById(id_preview)
             gallery = gradioApp().getElementById(id_gallery)
 
             if(preview != null && gallery != null){
                 preview.style.width = gallery.clientWidth + "px"
                 preview.style.height = gallery.clientHeight + "px"
+                if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
 
 				//only watch gallery if there is a generation process going on
                 check_gallery(id_gallery);
 
                 var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
-                if(!progressDiv){
+                if(progressDiv){
+                    timeoutIds[id_part] = window.setTimeout(function() {
+                        timeoutIds[id_part] = null
+                        requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
+                    }, 500)
+                } else{
                     if (skip) {
                         skip.style.display = "none"
                     }
                     interrupt.style.display = "none"
-			
+
                     //disconnect observer once generation finished, so user can close selected image if they want
                     if (galleryObservers[id_gallery]) {
                         galleryObservers[id_gallery].disconnect();
                         galleries[id_gallery] = null;
-                    }    
+                    }
                 }
-
-
             }
 
-            window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
         });
         mutationObserver.observe( progressbar, { childList:true, subtree:true })
 	}
@@ -74,14 +92,26 @@ function check_gallery(id_gallery){
             if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
                 // automatically re-open previously selected index (if exists)
                 activeElement = gradioApp().activeElement;
+                let scrollX = window.scrollX;
+                let scrollY = window.scrollY;
 
                 galleryButtons[prevSelectedIndex].click();
                 showGalleryImage();
 
+                // When the gallery button is clicked, it gains focus and scrolls itself into view
+                // We need to scroll back to the previous position
+                setTimeout(function (){
+                    window.scrollTo(scrollX, scrollY);
+                }, 50);
+
                 if(activeElement){
                     // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
-                    // if somenoe has a better solution please by all means
-                    setTimeout(function() { activeElement.focus() }, 1);
+                    // if someone has a better solution please by all means
+                    setTimeout(function (){
+                        activeElement.focus({
+                            preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
+                        })
+                    }, 1);
                 }
             }
         })
diff --git a/javascript/ui.js b/javascript/ui.js
index cfd0dcd3..ee226927 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -1,4 +1,4 @@
-// various functions for interation with ui.py not large enough to warrant putting them in separate files
+// various functions for interaction with ui.py not large enough to warrant putting them in separate files
 
 function set_theme(theme){
     gradioURL = window.location.href
@@ -8,8 +8,8 @@ function set_theme(theme){
 }
 
 function selected_gallery_index(){
-    var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
-    var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')
+    var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item')
+    var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2')
 
     var result = -1
     buttons.forEach(function(v, i){ if(v==button) { result = i } })
@@ -19,7 +19,7 @@ function selected_gallery_index(){
 
 function extract_image_from_gallery(gallery){
     if(gallery.length == 1){
-        return gallery[0]
+        return [gallery[0]]
     }
 
     index = selected_gallery_index()
@@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){
         return [null]
     }
 
-    return gallery[index];
+    return [gallery[index]];
 }
 
 function args_to_array(args){
@@ -45,14 +45,14 @@ function switch_to_txt2img(){
     return args_to_array(arguments);
 }
 
-function switch_to_img2img_img2img(){
+function switch_to_img2img(){
     gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
     gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
 
     return args_to_array(arguments);
 }
 
-function switch_to_img2img_inpaint(){
+function switch_to_inpaint(){
     gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
     gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
 
@@ -65,26 +65,6 @@ function switch_to_extras(){
     return args_to_array(arguments);
 }
 
-function extract_image_from_gallery_txt2img(gallery){
-    switch_to_txt2img()
-    return extract_image_from_gallery(gallery);
-}
-
-function extract_image_from_gallery_img2img(gallery){
-    switch_to_img2img_img2img()
-    return extract_image_from_gallery(gallery);
-}
-
-function extract_image_from_gallery_inpaint(gallery){
-    switch_to_img2img_inpaint()
-    return extract_image_from_gallery(gallery);
-}
-
-function extract_image_from_gallery_extras(gallery){
-    switch_to_extras()
-    return extract_image_from_gallery(gallery);
-}
-
 function get_tab_index(tabId){
     var res = 0
 
@@ -120,7 +100,7 @@ function create_submit_args(args){
 
     // As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
     // This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
-    // I don't know why gradio is seding outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
+    // I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
     // If gradio at some point stops sending outputs, this may break something
     if(Array.isArray(res[res.length - 3])){
         res[res.length - 3] = null
@@ -151,6 +131,15 @@ function ask_for_style_name(_, prompt_text, negative_prompt_text) {
     return [name_, prompt_text, negative_prompt_text]
 }
 
+function confirm_clear_prompt(prompt, negative_prompt) {
+    if(confirm("Delete prompt?")) {
+        prompt = ""
+        negative_prompt = ""
+    }
+
+    return [prompt, negative_prompt]
+}
+
 
 
 opts = {}
@@ -199,6 +188,17 @@ onUiUpdate(function(){
 		img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
 		img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
 	}
+
+    show_all_pages = gradioApp().getElementById('settings_show_all_pages')
+    settings_tabs = gradioApp().querySelector('#settings div')
+    if(show_all_pages && settings_tabs){
+        settings_tabs.appendChild(show_all_pages)
+        show_all_pages.onclick = function(){
+            gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
+                elem.style.display = "block";
+            })
+        }
+    }
 })
 
 let txt2img_textarea, img2img_textarea = undefined;
@@ -228,4 +228,6 @@ function update_token_counter(button_id) {
 function restart_reload(){
     document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
     setTimeout(function(){location.reload()},2000)
+
+    return []
 }
diff --git a/launch.py b/launch.py
index 7b15e78e..af0d418b 100644
--- a/launch.py
+++ b/launch.py
@@ -5,8 +5,11 @@ import sys
 import importlib.util
 import shlex
 import platform
+import argparse
+import json
 
 dir_repos = "repositories"
+dir_extensions = "extensions"
 python = sys.executable
 git = os.environ.get('GIT', "git")
 index_url = os.environ.get('INDEX_URL', "")
@@ -16,11 +19,24 @@ def extract_arg(args, name):
     return [x for x in args if x != name], name in args
 
 
-def run(command, desc=None, errdesc=None):
+def extract_opt(args, name):
+    opt = None
+    is_present = False
+    if name in args:
+        is_present = True
+        idx = args.index(name)
+        del args[idx]
+        if idx < len(args) and args[idx][0] != "-":
+            opt = args[idx]
+            del args[idx]
+    return args, is_present, opt
+
+
+def run(command, desc=None, errdesc=None, custom_env=None):
     if desc is not None:
         print(desc)
 
-    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
+    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
 
     if result.returncode != 0:
 
@@ -101,39 +117,81 @@ def version_check(commit):
         else:
             print("Not a git clone, can't perform version check.")
     except Exception as e:
-        print("versipm check failed",e)
+        print("version check failed", e)
 
-        
-def prepare_enviroment():
+
+def run_extension_installer(extension_dir):
+    path_installer = os.path.join(extension_dir, "install.py")
+    if not os.path.isfile(path_installer):
+        return
+
+    try:
+        env = os.environ.copy()
+        env['PYTHONPATH'] = os.path.abspath(".")
+
+        print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
+    except Exception as e:
+        print(e, file=sys.stderr)
+
+
+def list_extensions(settings_file):
+    settings = {}
+
+    try:
+        if os.path.isfile(settings_file):
+            with open(settings_file, "r", encoding="utf8") as file:
+                settings = json.load(file)
+    except Exception as e:
+        print(e, file=sys.stderr)
+
+    disabled_extensions = set(settings.get('disabled_extensions', []))
+
+    return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
+
+
+def run_extensions_installers(settings_file):
+    if not os.path.isdir(dir_extensions):
+        return
+
+    for dirname_extension in list_extensions(settings_file):
+        run_extension_installer(os.path.join(dir_extensions, dirname_extension))
+
+
+def prepare_environment():
     torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
     commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
 
     gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
     clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
-    deepdanbooru_package = os.environ.get('DEEPDANBOORU_PACKAGE', "git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26")
+    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
 
     xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
 
-    stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git")
-    taming_transformers_repo = os.environ.get('TAMING_REANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
+    stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+    taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
-    codeformer_repo = os.environ.get('CODEFORMET_REPO', 'https://github.com/sczhou/CodeFormer.git')
+    codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
 
-    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
+    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
     taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
-    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
+    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
 
     sys.argv += shlex.split(commandline_args)
 
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
+    args, _ = parser.parse_known_args(sys.argv)
+
+    sys.argv, _ = extract_arg(sys.argv, '-f')
     sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
     sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
     sys.argv, update_check = extract_arg(sys.argv, '--update-check')
+    sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
     xformers = '--xformers' in sys.argv
-    deepdanbooru = '--deepdanbooru' in sys.argv
     ngrok = '--ngrok' in sys.argv
 
     try:
@@ -156,21 +214,27 @@ def prepare_enviroment():
     if not is_installed("clip"):
         run_pip(f"install {clip_package}", "clip")
 
-    if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"):
+    if not is_installed("open_clip"):
+        run_pip(f"install {openclip_package}", "open_clip")
+
+    if (not is_installed("xformers") or reinstall_xformers) and xformers:
         if platform.system() == "Windows":
-            run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
+            if platform.python_version().startswith("3.10"):
+                run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
+            else:
+                print("Installation of xformers is not supported in this version of Python.")
+                print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
+                if not is_installed("xformers"):
+                    exit(0)
         elif platform.system() == "Linux":
             run_pip("install xformers", "xformers")
 
-    if not is_installed("deepdanbooru") and deepdanbooru:
-        run_pip(f"install {deepdanbooru_package}#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")
-
     if not is_installed("pyngrok") and ngrok:
         run_pip("install pyngrok", "ngrok")
 
     os.makedirs(dir_repos, exist_ok=True)
 
-    git_clone(stable_diffusion_repo, repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
+    git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
     git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
     git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
     git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
@@ -181,6 +245,8 @@ def prepare_enviroment():
 
     run_pip(f"install -r {requirements_file}", "requirements for Web UI")
 
+    run_extensions_installers(settings_file=args.ui_settings_file)
+
     if update_check:
         version_check(commit)
     
@@ -188,13 +254,42 @@ def prepare_enviroment():
         print("Exiting because of --exit argument")
         exit(0)
 
+    if run_tests:
+        exitcode = tests(test_dir)
+        exit(exitcode)
 
-def start_webui():
-    print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
+
+def tests(test_dir):
+    if "--api" not in sys.argv:
+        sys.argv.append("--api")
+    if "--ckpt" not in sys.argv:
+        sys.argv.append("--ckpt")
+        sys.argv.append("./test/test_files/empty.pt")
+    if "--skip-torch-cuda-test" not in sys.argv:
+        sys.argv.append("--skip-torch-cuda-test")
+
+    print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
+
+    with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
+        proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
+
+    import test.server_poll
+    exitcode = test.server_poll.run_tests(proc, test_dir)
+
+    print(f"Stopping Web UI process with id {proc.pid}")
+    proc.kill()
+    return exitcode
+
+
+def start():
+    print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
     import webui
-    webui.webui()
+    if '--nowebui' in sys.argv:
+        webui.api_only()
+    else:
+        webui.webui()
 
 
 if __name__ == "__main__":
-    prepare_enviroment()
-    start_webui()
+    prepare_environment()
+    start()
diff --git a/models/VAE-approx/model.pt b/models/VAE-approx/model.pt
new file mode 100644
index 00000000..8bda9d6e
Binary files /dev/null and b/models/VAE-approx/model.pt differ
diff --git a/models/VAE/Put VAE here.txt b/models/VAE/Put VAE here.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/modules/api/api.py b/modules/api/api.py
index 5b0c934e..48a70a44 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,67 +1,461 @@
-from modules.api.processing import StableDiffusionProcessingAPI
-from modules.processing import StableDiffusionProcessingTxt2Img, process_images
-from modules.sd_samplers import all_samplers
-from modules.extras import run_pnginfo
-import modules.shared as shared
-import uvicorn
-from fastapi import Body, APIRouter, HTTPException
-from fastapi.responses import JSONResponse
-from pydantic import BaseModel, Field, Json
-import json
-import io
 import base64
+import io
+import time
+import datetime
+import uvicorn
+from threading import Lock
+from io import BytesIO
+from gradio.processing_utils import decode_base64_to_file
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from secrets import compare_digest
 
-sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+import modules.shared as shared
+from modules import sd_samplers, deepbooru, sd_hijack
+from modules.api.models import *
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.extras import run_extras, run_pnginfo
+from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
+from modules.textual_inversion.preprocess import preprocess
+from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
+from PIL import PngImagePlugin,Image
+from modules.sd_models import checkpoints_list, find_checkpoint_config
+from modules.realesrgan_model import get_realesrgan_models
+from modules import devices
+from typing import List
 
-class TextToImageResponse(BaseModel):
-    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
-    parameters: Json
-    info: Json
+def upscaler_to_index(name: str):
+    try:
+        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
+    except:
+        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
+
+
+def validate_sampler_name(name):
+    config = sd_samplers.all_samplers_map.get(name, None)
+    if config is None:
+        raise HTTPException(status_code=404, detail="Sampler not found")
+
+    return name
+
+def setUpscalers(req: dict):
+    reqDict = vars(req)
+    reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
+    reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
+    reqDict.pop('upscaler_1')
+    reqDict.pop('upscaler_2')
+    return reqDict
+
+def decode_base64_to_image(encoding):
+    if encoding.startswith("data:image/"):
+        encoding = encoding.split(";")[1].split(",")[1]
+    return Image.open(BytesIO(base64.b64decode(encoding)))
+
+def encode_pil_to_base64(image):
+    with io.BytesIO() as output_bytes:
+
+        # Copy any text-only metadata
+        use_metadata = False
+        metadata = PngImagePlugin.PngInfo()
+        for key, value in image.info.items():
+            if isinstance(key, str) and isinstance(value, str):
+                metadata.add_text(key, value)
+                use_metadata = True
+
+        image.save(
+            output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
+        )
+        bytes_data = output_bytes.getvalue()
+    return base64.b64encode(bytes_data)
+
+def api_middleware(app: FastAPI):
+    @app.middleware("http")
+    async def log_and_time(req: Request, call_next):
+        ts = time.time()
+        res: Response = await call_next(req)
+        duration = str(round(time.time() - ts, 4))
+        res.headers["X-Process-Time"] = duration
+        endpoint = req.scope.get('path', 'err')
+        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
+            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
+                t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+                code = res.status_code,
+                ver = req.scope.get('http_version', '0.0'),
+                cli = req.scope.get('client', ('0:0.0.0', 0))[0],
+                prot = req.scope.get('scheme', 'err'),
+                method = req.scope.get('method', 'err'),
+                endpoint = endpoint,
+                duration = duration,
+            ))
+        return res
 
 
 class Api:
-    def __init__(self, app, queue_lock):
+    def __init__(self, app: FastAPI, queue_lock: Lock):
+        if shared.cmd_opts.api_auth:
+            self.credentials = dict()
+            for auth in shared.cmd_opts.api_auth.split(","):
+                user, password = auth.split(":")
+                self.credentials[user] = password
+
         self.router = APIRouter()
         self.app = app
         self.queue_lock = queue_lock
-        self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
+        api_middleware(self.app)
+        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
+        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
+        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
+        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
+        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
+        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
+        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
+        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
+        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
+        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
+        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
+        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
+        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
+        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
+        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
+        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
+        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
+        self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
+        self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
+        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
+        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
+        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
+        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
+        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
+        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
+        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
 
-    def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
-        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
-        
-        if sampler_index is None:
-            raise HTTPException(status_code=404, detail="Sampler not found") 
-        
+    def add_api_route(self, path: str, endpoint, **kwargs):
+        if shared.cmd_opts.api_auth:
+            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
+        return self.app.add_api_route(path, endpoint, **kwargs)
+
+    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
+        if credentials.username in self.credentials:
+            if compare_digest(credentials.password, self.credentials[credentials.username]):
+                return True
+
+        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
+
+    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
         populate = txt2imgreq.copy(update={ # Override __init__ params
-            "sd_model": shared.sd_model, 
-            "sampler_index": sampler_index[0],
+            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
             "do_not_save_samples": True,
             "do_not_save_grid": True
             }
         )
-        p = StableDiffusionProcessingTxt2Img(**vars(populate))
+        if populate.sampler_name:
+            populate.sampler_index = None  # prevent a warning later on
+
+        with self.queue_lock:
+            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
+
+            shared.state.begin()
+            processed = process_images(p)
+            shared.state.end()
+
+
+        b64images = list(map(encode_pil_to_base64, processed.images))
+
+        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
+
+    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+        init_images = img2imgreq.init_images
+        if init_images is None:
+            raise HTTPException(status_code=404, detail="Init image not found")
+
+        mask = img2imgreq.mask
+        if mask:
+            mask = decode_base64_to_image(mask)
+
+        populate = img2imgreq.copy(update={ # Override __init__ params
+            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
+            "do_not_save_samples": True,
+            "do_not_save_grid": True,
+            "mask": mask
+            }
+        )
+        if populate.sampler_name:
+            populate.sampler_index = None  # prevent a warning later on
+
+        args = vars(populate)
+        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+
+        with self.queue_lock:
+            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
+            p.init_images = [decode_base64_to_image(x) for x in init_images]
+
+            shared.state.begin()
+            processed = process_images(p)
+            shared.state.end()
+
+        b64images = list(map(encode_pil_to_base64, processed.images))
+
+        if not img2imgreq.include_init_images:
+            img2imgreq.init_images = None
+            img2imgreq.mask = None
+
+        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+
+    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+        reqDict = setUpscalers(req)
+
+        reqDict['image'] = decode_base64_to_image(reqDict['image'])
+
+        with self.queue_lock:
+            result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
+
+        return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+
+    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+        reqDict = setUpscalers(req)
+
+        def prepareFiles(file):
+            file = decode_base64_to_file(file.data, file_path=file.name)
+            file.orig_name = file.name
+            return file
+
+        reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
+        reqDict.pop('imageList')
+
+        with self.queue_lock:
+            result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+
+        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
+
+    def pnginfoapi(self, req: PNGInfoRequest):
+        if(not req.image.strip()):
+            return PNGInfoResponse(info="")
+
+        result = run_pnginfo(decode_base64_to_image(req.image.strip()))
+
+        return PNGInfoResponse(info=result[1])
+
+    def progressapi(self, req: ProgressRequest = Depends()):
+        # copy from check_progress_call of ui.py
+
+        if shared.state.job_count == 0:
+            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
+
+        # avoid dividing zero
+        progress = 0.01
+
+        if shared.state.job_count > 0:
+            progress += shared.state.job_no / shared.state.job_count
+        if shared.state.sampling_steps > 0:
+            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+        time_since_start = time.time() - shared.state.time_start
+        eta = (time_since_start/progress)
+        eta_relative = eta-time_since_start
+
+        progress = min(progress, 1)
+
+        shared.state.set_current_image()
+
+        current_image = None
+        if shared.state.current_image and not req.skip_current_image:
+            current_image = encode_pil_to_base64(shared.state.current_image)
+
+        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
+
+    def interrogateapi(self, interrogatereq: InterrogateRequest):
+        image_b64 = interrogatereq.image
+        if image_b64 is None:
+            raise HTTPException(status_code=404, detail="Image not found")
+
+        img = decode_base64_to_image(image_b64)
+        img = img.convert('RGB')
+
         # Override object param
         with self.queue_lock:
-            processed = process_images(p)
-        
-        b64images = []
-        for i in processed.images:
-            buffer = io.BytesIO()
-            i.save(buffer, format="png")
-            b64images.append(base64.b64encode(buffer.getvalue()))
+            if interrogatereq.model == "clip":
+                processed = shared.interrogator.interrogate(img)
+            elif interrogatereq.model == "deepdanbooru":
+                processed = deepbooru.model.tag(img)
+            else:
+                raise HTTPException(status_code=404, detail="Model not found")
 
-        return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
-        
-        
+        return InterrogateResponse(caption=processed)
 
-    def img2imgapi(self):
-        raise NotImplementedError
+    def interruptapi(self):
+        shared.state.interrupt()
 
-    def extrasapi(self):
-        raise NotImplementedError
+        return {}
 
-    def pnginfoapi(self):
-        raise NotImplementedError
+    def skip(self):
+        shared.state.skip()
+
+    def get_config(self):
+        options = {}
+        for key in shared.opts.data.keys():
+            metadata = shared.opts.data_labels.get(key)
+            if(metadata is not None):
+                options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
+            else:
+                options.update({key: shared.opts.data.get(key, None)})
+
+        return options
+
+    def set_config(self, req: Dict[str, Any]):
+        for k, v in req.items():
+            shared.opts.set(k, v)
+
+        shared.opts.save(shared.config_filename)
+        return
+
+    def get_cmd_flags(self):
+        return vars(shared.cmd_opts)
+
+    def get_samplers(self):
+        return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
+
+    def get_upscalers(self):
+        upscalers = []
+
+        for upscaler in shared.sd_upscalers:
+            u = upscaler.scaler
+            upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
+
+        return upscalers
+
+    def get_sd_models(self):
+        return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+
+    def get_hypernetworks(self):
+        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
+
+    def get_face_restorers(self):
+        return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
+
+    def get_realesrgan_models(self):
+        return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
+
+    def get_prompt_styles(self):
+        styleList = []
+        for k in shared.prompt_styles.styles:
+            style = shared.prompt_styles.styles[k]
+            styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
+
+        return styleList
+
+    def get_artists_categories(self):
+        return shared.artist_db.cats
+
+    def get_artists(self):
+        return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
+
+    def get_embeddings(self):
+        db = sd_hijack.model_hijack.embedding_db
+
+        def convert_embedding(embedding):
+            return {
+                "step": embedding.step,
+                "sd_checkpoint": embedding.sd_checkpoint,
+                "sd_checkpoint_name": embedding.sd_checkpoint_name,
+                "shape": embedding.shape,
+                "vectors": embedding.vectors,
+            }
+
+        def convert_embeddings(embeddings):
+            return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
+        return {
+            "loaded": convert_embeddings(db.word_embeddings),
+            "skipped": convert_embeddings(db.skipped_embeddings),
+        }
+
+    def refresh_checkpoints(self):
+        shared.refresh_checkpoints()
+
+    def create_embedding(self, args: dict):
+        try:
+            shared.state.begin()
+            filename = create_embedding(**args) # create empty embedding
+            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
+            shared.state.end()
+            return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
+        except AssertionError as e:
+            shared.state.end()
+            return TrainResponse(info = "create embedding error: {error}".format(error = e))
+
+    def create_hypernetwork(self, args: dict):
+        try:
+            shared.state.begin()
+            filename = create_hypernetwork(**args) # create empty embedding
+            shared.state.end()
+            return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
+        except AssertionError as e:
+            shared.state.end()
+            return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
+
+    def preprocess(self, args: dict):
+        try:
+            shared.state.begin()
+            preprocess(**args) # quick operation unless blip/booru interrogation is enabled
+            shared.state.end()
+            return PreprocessResponse(info = 'preprocess complete')
+        except KeyError as e:
+            shared.state.end()
+            return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
+        except AssertionError as e:
+            shared.state.end()
+            return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
+        except FileNotFoundError as e:
+            shared.state.end()
+            return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
+
+    def train_embedding(self, args: dict):
+        try:
+            shared.state.begin()
+            apply_optimizations = shared.opts.training_xattention_optimizations
+            error = None
+            filename = ''
+            if not apply_optimizations:
+                sd_hijack.undo_optimizations()
+            try:
+                embedding, filename = train_embedding(**args) # can take a long time to complete
+            except Exception as e:
+                error = e
+            finally:
+                if not apply_optimizations:
+                    sd_hijack.apply_optimizations()
+                shared.state.end()
+            return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+        except AssertionError as msg:
+            shared.state.end()
+            return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
+
+    def train_hypernetwork(self, args: dict):
+        try:
+            shared.state.begin()
+            initial_hypernetwork = shared.loaded_hypernetwork
+            apply_optimizations = shared.opts.training_xattention_optimizations
+            error = None
+            filename = ''
+            if not apply_optimizations:
+                sd_hijack.undo_optimizations()
+            try:
+                hypernetwork, filename = train_hypernetwork(*args)
+            except Exception as e:
+                error = e
+            finally:
+                shared.loaded_hypernetwork = initial_hypernetwork
+                shared.sd_model.cond_stage_model.to(devices.device)
+                shared.sd_model.first_stage_model.to(devices.device)
+                if not apply_optimizations:
+                    sd_hijack.apply_optimizations()
+                shared.state.end()
+            return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+        except AssertionError as msg:
+            shared.state.end()
+            return TrainResponse(info = "train embedding error: {error}".format(error = error))
 
     def launch(self, server_name, port):
         self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
new file mode 100644
index 00000000..4a632c68
--- /dev/null
+++ b/modules/api/models.py
@@ -0,0 +1,261 @@
+import inspect
+from pydantic import BaseModel, Field, create_model
+from typing import Any, Optional
+from typing_extensions import Literal
+from inflection import underscore
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
+from modules.shared import sd_upscalers, opts, parser
+from typing import Dict, List
+
+API_NOT_ALLOWED = [
+    "self",
+    "kwargs",
+    "sd_model",
+    "outpath_samples",
+    "outpath_grids",
+    "sampler_index",
+    "do_not_save_samples",
+    "do_not_save_grid",
+    "extra_generation_params",
+    "overlay_images",
+    "do_not_reload_embeddings",
+    "seed_enable_extras",
+    "prompt_for_display",
+    "sampler_noise_scheduler_override",
+    "ddim_discretize"
+]
+
+class ModelDef(BaseModel):
+    """Assistance Class for Pydantic Dynamic Model Generation"""
+
+    field: str
+    field_alias: str
+    field_type: Any
+    field_value: Any
+    field_exclude: bool = False
+
+
+class PydanticModelGenerator:
+    """
+    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
+    source_data is a snapshot of the default values produced by the class
+    params are the names of the actual keys required by __init__
+    """
+
+    def __init__(
+        self,
+        model_name: str = None,
+        class_instance = None,
+        additional_fields = None,
+    ):
+        def field_type_generator(k, v):
+            # field_type = str if not overrides.get(k) else overrides[k]["type"]
+            # print(k, v.annotation, v.default)
+            field_type = v.annotation
+
+            return Optional[field_type]
+
+        def merge_class_params(class_):
+            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+            parameters = {}
+            for classes in all_classes:
+                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+            return parameters
+
+
+        self._model_name = model_name
+        self._class_data = merge_class_params(class_instance)
+
+        self._model_def = [
+            ModelDef(
+                field=underscore(k),
+                field_alias=k,
+                field_type=field_type_generator(k, v),
+                field_value=v.default
+            )
+            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
+        ]
+
+        for fields in additional_fields:
+            self._model_def.append(ModelDef(
+                field=underscore(fields["key"]),
+                field_alias=fields["key"],
+                field_type=fields["type"],
+                field_value=fields["default"],
+                field_exclude=fields["exclude"] if "exclude" in fields else False))
+
+    def generate_model(self):
+        """
+        Creates a pydantic BaseModel
+        from the json and overrides provided at initialization
+        """
+        fields = {
+            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
+        }
+        DynamicModel = create_model(self._model_name, **fields)
+        DynamicModel.__config__.allow_population_by_field_name = True
+        DynamicModel.__config__.allow_mutation = True
+        return DynamicModel
+
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
+    "StableDiffusionProcessingTxt2Img",
+    StableDiffusionProcessingTxt2Img,
+    [{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+    "StableDiffusionProcessingImg2Img",
+    StableDiffusionProcessingImg2Img,
+    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
+).generate_model()
+
+class TextToImageResponse(BaseModel):
+    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    parameters: dict
+    info: str
+
+class ImageToImageResponse(BaseModel):
+    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    parameters: dict
+    info: str
+
+class ExtrasBaseRequest(BaseModel):
+    resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
+    show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
+    gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
+    codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
+    codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
+    upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
+    upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
+    upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
+    upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
+    upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+    upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+    extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
+    upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
+
+class ExtraBaseResponse(BaseModel):
+    html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
+
+class ExtrasSingleImageRequest(ExtrasBaseRequest):
+    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+
+class ExtrasSingleImageResponse(ExtraBaseResponse):
+    image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
+
+class FileData(BaseModel):
+    data: str = Field(title="File data", description="Base64 representation of the file")
+    name: str = Field(title="File name")
+
+class ExtrasBatchImagesRequest(ExtrasBaseRequest):
+    imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+
+class ExtrasBatchImagesResponse(ExtraBaseResponse):
+    images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+
+class PNGInfoRequest(BaseModel):
+    image: str = Field(title="Image", description="The base64 encoded PNG image")
+
+class PNGInfoResponse(BaseModel):
+    info: str = Field(title="Image info", description="A string with all the info the image had")
+
+class ProgressRequest(BaseModel):
+    skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
+
+class ProgressResponse(BaseModel):
+    progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
+    eta_relative: float = Field(title="ETA in secs")
+    state: dict = Field(title="State", description="The current state snapshot")
+    current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+
+class InterrogateRequest(BaseModel):
+    image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+    model: str = Field(default="clip", title="Model", description="The interrogate model used.")
+
+class InterrogateResponse(BaseModel):
+    caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
+
+class TrainResponse(BaseModel):
+    info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
+
+class CreateResponse(BaseModel):
+    info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
+
+class PreprocessResponse(BaseModel):
+    info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
+
+fields = {}
+for key, metadata in opts.data_labels.items():
+    value = opts.data.get(key)
+    optType = opts.typemap.get(type(metadata.default), type(value))
+
+    if (metadata is not None):
+        fields.update({key: (Optional[optType], Field(
+            default=metadata.default ,description=metadata.label))})
+    else:
+        fields.update({key: (Optional[optType], Field())})
+
+OptionsModel = create_model("Options", **fields)
+
+flags = {}
+_options = vars(parser)['_option_string_actions']
+for key in _options:
+    if(_options[key].dest != 'help'):
+        flag = _options[key]
+        _type = str
+        if _options[key].default is not None: _type = type(_options[key].default)
+        flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+
+FlagsModel = create_model("Flags", **flags)
+
+class SamplerItem(BaseModel):
+    name: str = Field(title="Name")
+    aliases: List[str] = Field(title="Aliases")
+    options: Dict[str, str] = Field(title="Options")
+
+class UpscalerItem(BaseModel):
+    name: str = Field(title="Name")
+    model_name: Optional[str] = Field(title="Model Name")
+    model_path: Optional[str] = Field(title="Path")
+    model_url: Optional[str] = Field(title="URL")
+
+class SDModelItem(BaseModel):
+    title: str = Field(title="Title")
+    model_name: str = Field(title="Model Name")
+    hash: str = Field(title="Hash")
+    filename: str = Field(title="Filename")
+    config: str = Field(title="Config file")
+
+class HypernetworkItem(BaseModel):
+    name: str = Field(title="Name")
+    path: Optional[str] = Field(title="Path")
+
+class FaceRestorerItem(BaseModel):
+    name: str = Field(title="Name")
+    cmd_dir: Optional[str] = Field(title="Path")
+
+class RealesrganItem(BaseModel):
+    name: str = Field(title="Name")
+    path: Optional[str] = Field(title="Path")
+    scale: Optional[int] = Field(title="Scale")
+
+class PromptStyleItem(BaseModel):
+    name: str = Field(title="Name")
+    prompt: Optional[str] = Field(title="Prompt")
+    negative_prompt: Optional[str] = Field(title="Negative Prompt")
+
+class ArtistItem(BaseModel):
+    name: str = Field(title="Name")
+    score: float = Field(title="Score")
+    category: str = Field(title="Category")
+
+class EmbeddingItem(BaseModel):
+    step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+    sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+    sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+    shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+    vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
+class EmbeddingsResponse(BaseModel):
+    loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
\ No newline at end of file
diff --git a/modules/api/processing.py b/modules/api/processing.py
deleted file mode 100644
index 4c541241..00000000
--- a/modules/api/processing.py
+++ /dev/null
@@ -1,99 +0,0 @@
-from inflection import underscore
-from typing import Any, Dict, Optional
-from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessingTxt2Img
-import inspect
-
-
-API_NOT_ALLOWED = [
-    "self",
-    "kwargs",
-    "sd_model",
-    "outpath_samples",
-    "outpath_grids",
-    "sampler_index",
-    "do_not_save_samples",
-    "do_not_save_grid",
-    "extra_generation_params",
-    "overlay_images",
-    "do_not_reload_embeddings",
-    "seed_enable_extras",
-    "prompt_for_display",
-    "sampler_noise_scheduler_override",
-    "ddim_discretize"
-]
-
-class ModelDef(BaseModel):
-    """Assistance Class for Pydantic Dynamic Model Generation"""
-
-    field: str
-    field_alias: str
-    field_type: Any
-    field_value: Any
-
-
-class PydanticModelGenerator:
-    """
-    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
-    source_data is a snapshot of the default values produced by the class
-    params are the names of the actual keys required by __init__
-    """
-
-    def __init__(
-        self,
-        model_name: str = None,
-        class_instance = None,
-        additional_fields = None,
-    ):
-        def field_type_generator(k, v):
-            # field_type = str if not overrides.get(k) else overrides[k]["type"]
-            # print(k, v.annotation, v.default)
-            field_type = v.annotation
-            
-            return Optional[field_type]
-        
-        def merge_class_params(class_):
-            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
-            parameters = {}
-            for classes in all_classes:
-                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
-            return parameters
-            
-                
-        self._model_name = model_name
-        self._class_data = merge_class_params(class_instance)
-        self._model_def = [
-            ModelDef(
-                field=underscore(k),
-                field_alias=k,
-                field_type=field_type_generator(k, v),
-                field_value=v.default
-            )
-            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
-        ]
-        
-        for fields in additional_fields:
-            self._model_def.append(ModelDef(
-                field=underscore(fields["key"]), 
-                field_alias=fields["key"], 
-                field_type=fields["type"],
-                field_value=fields["default"]))
-
-    def generate_model(self):
-        """
-        Creates a pydantic BaseModel
-        from the json and overrides provided at initialization
-        """
-        fields = {
-            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
-        }
-        DynamicModel = create_model(self._model_name, **fields)
-        DynamicModel.__config__.allow_population_by_field_name = True
-        DynamicModel.__config__.allow_mutation = True
-        return DynamicModel
-    
-StableDiffusionProcessingAPI = PydanticModelGenerator(
-    "StableDiffusionProcessingTxt2Img", 
-    StableDiffusionProcessingTxt2Img,
-    [{"key": "sampler_index", "type": str, "default": "Euler"}]
-).generate_model()
\ No newline at end of file
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
deleted file mode 100644
index 737e1a76..00000000
--- a/modules/bsrgan_model.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import os.path
-import sys
-import traceback
-
-import PIL.Image
-import numpy as np
-import torch
-from basicsr.utils.download_util import load_file_from_url
-
-import modules.upscaler
-from modules import devices, modelloader
-from modules.bsrgan_model_arch import RRDBNet
-
-
-class UpscalerBSRGAN(modules.upscaler.Upscaler):
-    def __init__(self, dirname):
-        self.name = "BSRGAN"
-        self.model_name = "BSRGAN 4x"
-        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
-        self.user_path = dirname
-        super().__init__()
-        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
-        scalers = []
-        if len(model_paths) == 0:
-            scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
-            scalers.append(scaler_data)
-        for file in model_paths:
-            if "http" in file:
-                name = self.model_name
-            else:
-                name = modelloader.friendly_name(file)
-            try:
-                scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
-                scalers.append(scaler_data)
-            except Exception:
-                print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
-                print(traceback.format_exc(), file=sys.stderr)
-        self.scalers = scalers
-
-    def do_upscale(self, img: PIL.Image, selected_file):
-        torch.cuda.empty_cache()
-        model = self.load_model(selected_file)
-        if model is None:
-            return img
-        model.to(devices.device_bsrgan)
-        torch.cuda.empty_cache()
-        img = np.array(img)
-        img = img[:, :, ::-1]
-        img = np.moveaxis(img, 2, 0) / 255
-        img = torch.from_numpy(img).float()
-        img = img.unsqueeze(0).to(devices.device_bsrgan)
-        with torch.no_grad():
-            output = model(img)
-        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
-        output = 255. * np.moveaxis(output, 0, 2)
-        output = output.astype(np.uint8)
-        output = output[:, :, ::-1]
-        torch.cuda.empty_cache()
-        return PIL.Image.fromarray(output, 'RGB')
-
-    def load_model(self, path: str):
-        if "http" in path:
-            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
-                                          progress=True)
-        else:
-            filename = path
-        if not os.path.exists(filename) or filename is None:
-            print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
-            return None
-        model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4)  # define network
-        model.load_state_dict(torch.load(filename), strict=True)
-        model.eval()
-        for k, v in model.named_parameters():
-            v.requires_grad = False
-        return model
-
diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py
deleted file mode 100644
index cb4d1c13..00000000
--- a/modules/bsrgan_model_arch.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import functools
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.nn.init as init
-
-
-def initialize_weights(net_l, scale=1):
-    if not isinstance(net_l, list):
-        net_l = [net_l]
-    for net in net_l:
-        for m in net.modules():
-            if isinstance(m, nn.Conv2d):
-                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
-                m.weight.data *= scale  # for residual block
-                if m.bias is not None:
-                    m.bias.data.zero_()
-            elif isinstance(m, nn.Linear):
-                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
-                m.weight.data *= scale
-                if m.bias is not None:
-                    m.bias.data.zero_()
-            elif isinstance(m, nn.BatchNorm2d):
-                init.constant_(m.weight, 1)
-                init.constant_(m.bias.data, 0.0)
-
-
-def make_layer(block, n_layers):
-    layers = []
-    for _ in range(n_layers):
-        layers.append(block())
-    return nn.Sequential(*layers)
-
-
-class ResidualDenseBlock_5C(nn.Module):
-    def __init__(self, nf=64, gc=32, bias=True):
-        super(ResidualDenseBlock_5C, self).__init__()
-        # gc: growth channel, i.e. intermediate channels
-        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
-        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
-        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
-        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
-        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-        # initialization
-        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
-    def forward(self, x):
-        x1 = self.lrelu(self.conv1(x))
-        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
-        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
-        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
-        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-        return x5 * 0.2 + x
-
-
-class RRDB(nn.Module):
-    '''Residual in Residual Dense Block'''
-
-    def __init__(self, nf, gc=32):
-        super(RRDB, self).__init__()
-        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
-
-    def forward(self, x):
-        out = self.RDB1(x)
-        out = self.RDB2(out)
-        out = self.RDB3(out)
-        return out * 0.2 + x
-
-
-class RRDBNet(nn.Module):
-    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
-        super(RRDBNet, self).__init__()
-        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
-        self.sf = sf
-
-        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
-        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
-        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        #### upsampling
-        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        if self.sf==4:
-            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-    def forward(self, x):
-        fea = self.conv_first(x)
-        trunk = self.trunk_conv(self.RRDB_trunk(fea))
-        fea = fea + trunk
-
-        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
-        if self.sf==4:
-            fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
-        out = self.conv_last(self.lrelu(self.HRconv(fea)))
-
-        return out
\ No newline at end of file
diff --git a/modules/call_queue.py b/modules/call_queue.py
new file mode 100644
index 00000000..4cd49533
--- /dev/null
+++ b/modules/call_queue.py
@@ -0,0 +1,98 @@
+import html
+import sys
+import threading
+import traceback
+import time
+
+from modules import shared
+
+queue_lock = threading.Lock()
+
+
+def wrap_queued_call(func):
+    def f(*args, **kwargs):
+        with queue_lock:
+            res = func(*args, **kwargs)
+
+        return res
+
+    return f
+
+
+def wrap_gradio_gpu_call(func, extra_outputs=None):
+    def f(*args, **kwargs):
+
+        shared.state.begin()
+
+        with queue_lock:
+            res = func(*args, **kwargs)
+
+        shared.state.end()
+
+        return res
+
+    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
+
+
+def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
+        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
+        if run_memmon:
+            shared.mem_mon.monitor()
+        t = time.perf_counter()
+
+        try:
+            res = list(func(*args, **kwargs))
+        except Exception as e:
+            # When printing out our debug argument list, do not print out more than a MB of text
+            max_debug_str_len = 131072 # (1024*1024)/8
+
+            print("Error completing request", file=sys.stderr)
+            argStr = f"Arguments: {str(args)} {str(kwargs)}"
+            print(argStr[:max_debug_str_len], file=sys.stderr)
+            if len(argStr) > max_debug_str_len:
+                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
+
+            print(traceback.format_exc(), file=sys.stderr)
+
+            shared.state.job = ""
+            shared.state.job_count = 0
+
+            if extra_outputs_array is None:
+                extra_outputs_array = [None, '']
+
+            res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
+
+        shared.state.skipped = False
+        shared.state.interrupted = False
+        shared.state.job_count = 0
+
+        if not add_stats:
+            return tuple(res)
+
+        elapsed = time.perf_counter() - t
+        elapsed_m = int(elapsed // 60)
+        elapsed_s = elapsed % 60
+        elapsed_text = f"{elapsed_s:.2f}s"
+        if elapsed_m > 0:
+            elapsed_text = f"{elapsed_m}m "+elapsed_text
+
+        if run_memmon:
+            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
+            active_peak = mem_stats['active_peak']
+            reserved_peak = mem_stats['reserved_peak']
+            sys_peak = mem_stats['system_peak']
+            sys_total = mem_stats['total']
+            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+
+            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+        else:
+            vram_html = ''
+
+        # last item is always HTML
+        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
+
+        return tuple(res)
+
+    return f
+
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
index c06c590c..e7293683 100644
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module):
                 self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
                 logger.info(f'vqgan is loaded from: {model_path} [params]')
             else:
-                raise ValueError(f'Wrong params!')
+                raise ValueError('Wrong params!')
 
 
     def forward(self, x):
@@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module):
             elif 'params' in chkpt:
                 self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
             else:
-                raise ValueError(f'Wrong params!')
+                raise ValueError('Wrong params!')
 
     def forward(self, x):
         return self.main(x)
\ No newline at end of file
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index e6d9fa4f..ab40d842 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -36,6 +36,7 @@ def setup_model(dirname):
         from basicsr.utils.download_util import load_file_from_url
         from basicsr.utils import imwrite, img2tensor, tensor2img
         from facelib.utils.face_restoration_helper import FaceRestoreHelper
+        from facelib.detection.retinaface import retinaface
         from modules.shared import cmd_opts
 
         net_class = CodeFormer
@@ -65,6 +66,8 @@ def setup_model(dirname):
                 net.load_state_dict(checkpoint)
                 net.eval()
 
+                if hasattr(retinaface, 'device'):
+                    retinaface.device = devices.device_codeformer
                 face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
 
                 self.net = net
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 8914662d..122fce7f 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -1,172 +1,99 @@
-import os.path
-from concurrent.futures import ProcessPoolExecutor
-import multiprocessing
-import time
+import os
 import re
 
+import torch
+from PIL import Image
+import numpy as np
+
+from modules import modelloader, paths, deepbooru_model, devices, images, shared
+
 re_special = re.compile(r'([\\()])')
 
-def get_deepbooru_tags(pil_image):
-    """
-    This method is for running only one image at a time for simple use.  Used to the img2img interrogate.
-    """
-    from modules import shared  # prevents circular reference
 
-    try:
-        create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
-        return get_tags_from_process(pil_image)
-    finally:
-        release_process()
+class DeepDanbooru:
+    def __init__(self):
+        self.model = None
 
+    def load(self):
+        if self.model is not None:
+            return
 
-OPT_INCLUDE_RANKS = "include_ranks"
-def create_deepbooru_opts():
-    from modules import shared
+        files = modelloader.load_models(
+            model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
+            model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
+            ext_filter=[".pt"],
+            download_name='model-resnet_custom_v3.pt',
+        )
 
-    return {
-        "use_spaces": shared.opts.deepbooru_use_spaces,
-        "use_escape": shared.opts.deepbooru_escape,
-        "alpha_sort": shared.opts.deepbooru_sort_alpha,
-        OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
-    }
+        self.model = deepbooru_model.DeepDanbooruModel()
+        self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
 
+        self.model.eval()
+        self.model.to(devices.cpu, devices.dtype)
 
-def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
-    model, tags = get_deepbooru_tags_model()
-    while True: # while process is running, keep monitoring queue for new image
-        pil_image = queue.get()
-        if pil_image == "QUIT":
-            break
-        else:
-            deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
+    def start(self):
+        self.load()
+        self.model.to(devices.device)
 
+    def stop(self):
+        if not shared.opts.interrogate_keep_models_in_memory:
+            self.model.to(devices.cpu)
+            devices.torch_gc()
 
-def create_deepbooru_process(threshold, deepbooru_opts):
-    """
-    Creates deepbooru process.  A queue is created to send images into the process.  This enables multiple images
-    to be processed in a row without reloading the model or creating a new process.  To return the data, a shared
-    dictionary is created to hold the tags created.  To wait for tags to be returned, a value of -1 is assigned
-    to the dictionary and the method adding the image to the queue should wait for this value to be updated with
-    the tags.
-    """
-    from modules import shared  # prevents circular reference
-    shared.deepbooru_process_manager = multiprocessing.Manager()
-    shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
-    shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
-    shared.deepbooru_process_return["value"] = -1
-    shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
-    shared.deepbooru_process.start()
+    def tag(self, pil_image):
+        self.start()
+        res = self.tag_multi(pil_image)
+        self.stop()
 
+        return res
 
-def get_tags_from_process(image):
-    from modules import shared
+    def tag_multi(self, pil_image, force_disable_ranks=False):
+        threshold = shared.opts.interrogate_deepbooru_score_threshold
+        use_spaces = shared.opts.deepbooru_use_spaces
+        use_escape = shared.opts.deepbooru_escape
+        alpha_sort = shared.opts.deepbooru_sort_alpha
+        include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
 
-    shared.deepbooru_process_return["value"] = -1
-    shared.deepbooru_process_queue.put(image)
-    while shared.deepbooru_process_return["value"] == -1:
-        time.sleep(0.2)
-    caption = shared.deepbooru_process_return["value"]
-    shared.deepbooru_process_return["value"] = -1
+        pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
+        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
 
-    return caption
+        with torch.no_grad(), devices.autocast():
+            x = torch.from_numpy(a).to(devices.device)
+            y = self.model(x)[0].detach().cpu().numpy()
 
+        probability_dict = {}
 
-def release_process():
-    """
-    Stops the deepbooru process to return used memory
-    """
-    from modules import shared  # prevents circular reference
-    shared.deepbooru_process_queue.put("QUIT")
-    shared.deepbooru_process.join()
-    shared.deepbooru_process_queue = None
-    shared.deepbooru_process = None
-    shared.deepbooru_process_return = None
-    shared.deepbooru_process_manager = None
+        for tag, probability in zip(self.model.tags, y):
+            if probability < threshold:
+                continue
 
-def get_deepbooru_tags_model():
-    import deepdanbooru as dd
-    import tensorflow as tf
-    import numpy as np
-    this_folder = os.path.dirname(__file__)
-    model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
-    if not os.path.exists(os.path.join(model_path, 'project.json')):
-        # there is no point importing these every time
-        import zipfile
-        from basicsr.utils.download_util import load_file_from_url
-        load_file_from_url(
-            r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
-            model_path)
-        with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
-            zip_ref.extractall(model_path)
-        os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
-
-    tags = dd.project.load_tags_from_project(model_path)
-    model = dd.project.load_model_from_project(
-        model_path, compile_model=False
-    )
-    return model, tags
-
-
-def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
-    import deepdanbooru as dd
-    import tensorflow as tf
-    import numpy as np
-
-    alpha_sort = deepbooru_opts['alpha_sort']
-    use_spaces = deepbooru_opts['use_spaces']
-    use_escape = deepbooru_opts['use_escape']
-    include_ranks = deepbooru_opts['include_ranks']
-
-    width = model.input_shape[2]
-    height = model.input_shape[1]
-    image = np.array(pil_image)
-    image = tf.image.resize(
-        image,
-        size=(height, width),
-        method=tf.image.ResizeMethod.AREA,
-        preserve_aspect_ratio=True,
-    )
-    image = image.numpy()  # EagerTensor to np.array
-    image = dd.image.transform_and_pad_image(image, width, height)
-    image = image / 255.0
-    image_shape = image.shape
-    image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
-
-    y = model.predict(image)[0]
-
-    result_dict = {}
-
-    for i, tag in enumerate(tags):
-        result_dict[tag] = y[i]
-
-    unsorted_tags_in_theshold = []
-    result_tags_print = []
-    for tag in tags:
-        if result_dict[tag] >= threshold:
             if tag.startswith("rating:"):
                 continue
-            unsorted_tags_in_theshold.append((result_dict[tag], tag))
-            result_tags_print.append(f'{result_dict[tag]} {tag}')
 
-    # sort tags
-    result_tags_out = []
-    sort_ndx = 0
-    if alpha_sort:
-        sort_ndx = 1
+            probability_dict[tag] = probability
 
-    # sort by reverse by likelihood and normal for alpha, and format tag text as requested
-    unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
-    for weight, tag in unsorted_tags_in_theshold:
-        tag_outformat = tag
-        if use_spaces:
-            tag_outformat = tag_outformat.replace('_', ' ')
-        if use_escape:
-            tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
-        if include_ranks:
-            tag_outformat = f"({tag_outformat}:{weight:.3f})"
+        if alpha_sort:
+            tags = sorted(probability_dict)
+        else:
+            tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
 
-        result_tags_out.append(tag_outformat)
+        res = []
 
-    print('\n'.join(sorted(result_tags_print, reverse=True)))
+        filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
 
-    return ', '.join(result_tags_out)
+        for tag in [x for x in tags if x not in filtertags]:
+            probability = probability_dict[tag]
+            tag_outformat = tag
+            if use_spaces:
+                tag_outformat = tag_outformat.replace('_', ' ')
+            if use_escape:
+                tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
+            if include_ranks:
+                tag_outformat = f"({tag_outformat}:{probability:.3f})"
+
+            res.append(tag_outformat)
+
+        return ", ".join(res)
+
+
+model = DeepDanbooru()
diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py
new file mode 100644
index 00000000..edd40c81
--- /dev/null
+++ b/modules/deepbooru_model.py
@@ -0,0 +1,676 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
+
+
+class DeepDanbooruModel(nn.Module):
+    def __init__(self):
+        super(DeepDanbooruModel, self).__init__()
+
+        self.tags = []
+
+        self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
+        self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
+        self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+        self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
+        self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+        self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+        self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
+        self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+        self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+        self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
+        self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
+        self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
+        self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
+        self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
+        self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
+        self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+        self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+        self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+        self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+        self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+        self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+        self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+        self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+        self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+        self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+        self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+        self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+        self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+        self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+        self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+        self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+        self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+        self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+        self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+        self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
+        self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
+        self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
+        self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
+        self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
+        self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
+        self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
+        self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
+        self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
+        self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
+        self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
+        self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
+        self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
+        self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
+        self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+        self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
+        self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
+        self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+        self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
+        self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
+        self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
+        self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
+        self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
+        self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
+        self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+        self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
+        self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
+        self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+        self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
+        self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
+        self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
+        self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
+
+    def forward(self, *inputs):
+        t_358, = inputs
+        t_359 = t_358.permute(*[0, 3, 1, 2])
+        t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
+        t_360 = self.n_Conv_0(t_359_padded)
+        t_361 = F.relu(t_360)
+        t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
+        t_362 = self.n_MaxPool_0(t_361)
+        t_363 = self.n_Conv_1(t_362)
+        t_364 = self.n_Conv_2(t_362)
+        t_365 = F.relu(t_364)
+        t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
+        t_366 = self.n_Conv_3(t_365_padded)
+        t_367 = F.relu(t_366)
+        t_368 = self.n_Conv_4(t_367)
+        t_369 = torch.add(t_368, t_363)
+        t_370 = F.relu(t_369)
+        t_371 = self.n_Conv_5(t_370)
+        t_372 = F.relu(t_371)
+        t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
+        t_373 = self.n_Conv_6(t_372_padded)
+        t_374 = F.relu(t_373)
+        t_375 = self.n_Conv_7(t_374)
+        t_376 = torch.add(t_375, t_370)
+        t_377 = F.relu(t_376)
+        t_378 = self.n_Conv_8(t_377)
+        t_379 = F.relu(t_378)
+        t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
+        t_380 = self.n_Conv_9(t_379_padded)
+        t_381 = F.relu(t_380)
+        t_382 = self.n_Conv_10(t_381)
+        t_383 = torch.add(t_382, t_377)
+        t_384 = F.relu(t_383)
+        t_385 = self.n_Conv_11(t_384)
+        t_386 = self.n_Conv_12(t_384)
+        t_387 = F.relu(t_386)
+        t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
+        t_388 = self.n_Conv_13(t_387_padded)
+        t_389 = F.relu(t_388)
+        t_390 = self.n_Conv_14(t_389)
+        t_391 = torch.add(t_390, t_385)
+        t_392 = F.relu(t_391)
+        t_393 = self.n_Conv_15(t_392)
+        t_394 = F.relu(t_393)
+        t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
+        t_395 = self.n_Conv_16(t_394_padded)
+        t_396 = F.relu(t_395)
+        t_397 = self.n_Conv_17(t_396)
+        t_398 = torch.add(t_397, t_392)
+        t_399 = F.relu(t_398)
+        t_400 = self.n_Conv_18(t_399)
+        t_401 = F.relu(t_400)
+        t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
+        t_402 = self.n_Conv_19(t_401_padded)
+        t_403 = F.relu(t_402)
+        t_404 = self.n_Conv_20(t_403)
+        t_405 = torch.add(t_404, t_399)
+        t_406 = F.relu(t_405)
+        t_407 = self.n_Conv_21(t_406)
+        t_408 = F.relu(t_407)
+        t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
+        t_409 = self.n_Conv_22(t_408_padded)
+        t_410 = F.relu(t_409)
+        t_411 = self.n_Conv_23(t_410)
+        t_412 = torch.add(t_411, t_406)
+        t_413 = F.relu(t_412)
+        t_414 = self.n_Conv_24(t_413)
+        t_415 = F.relu(t_414)
+        t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
+        t_416 = self.n_Conv_25(t_415_padded)
+        t_417 = F.relu(t_416)
+        t_418 = self.n_Conv_26(t_417)
+        t_419 = torch.add(t_418, t_413)
+        t_420 = F.relu(t_419)
+        t_421 = self.n_Conv_27(t_420)
+        t_422 = F.relu(t_421)
+        t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
+        t_423 = self.n_Conv_28(t_422_padded)
+        t_424 = F.relu(t_423)
+        t_425 = self.n_Conv_29(t_424)
+        t_426 = torch.add(t_425, t_420)
+        t_427 = F.relu(t_426)
+        t_428 = self.n_Conv_30(t_427)
+        t_429 = F.relu(t_428)
+        t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
+        t_430 = self.n_Conv_31(t_429_padded)
+        t_431 = F.relu(t_430)
+        t_432 = self.n_Conv_32(t_431)
+        t_433 = torch.add(t_432, t_427)
+        t_434 = F.relu(t_433)
+        t_435 = self.n_Conv_33(t_434)
+        t_436 = F.relu(t_435)
+        t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
+        t_437 = self.n_Conv_34(t_436_padded)
+        t_438 = F.relu(t_437)
+        t_439 = self.n_Conv_35(t_438)
+        t_440 = torch.add(t_439, t_434)
+        t_441 = F.relu(t_440)
+        t_442 = self.n_Conv_36(t_441)
+        t_443 = self.n_Conv_37(t_441)
+        t_444 = F.relu(t_443)
+        t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
+        t_445 = self.n_Conv_38(t_444_padded)
+        t_446 = F.relu(t_445)
+        t_447 = self.n_Conv_39(t_446)
+        t_448 = torch.add(t_447, t_442)
+        t_449 = F.relu(t_448)
+        t_450 = self.n_Conv_40(t_449)
+        t_451 = F.relu(t_450)
+        t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
+        t_452 = self.n_Conv_41(t_451_padded)
+        t_453 = F.relu(t_452)
+        t_454 = self.n_Conv_42(t_453)
+        t_455 = torch.add(t_454, t_449)
+        t_456 = F.relu(t_455)
+        t_457 = self.n_Conv_43(t_456)
+        t_458 = F.relu(t_457)
+        t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
+        t_459 = self.n_Conv_44(t_458_padded)
+        t_460 = F.relu(t_459)
+        t_461 = self.n_Conv_45(t_460)
+        t_462 = torch.add(t_461, t_456)
+        t_463 = F.relu(t_462)
+        t_464 = self.n_Conv_46(t_463)
+        t_465 = F.relu(t_464)
+        t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
+        t_466 = self.n_Conv_47(t_465_padded)
+        t_467 = F.relu(t_466)
+        t_468 = self.n_Conv_48(t_467)
+        t_469 = torch.add(t_468, t_463)
+        t_470 = F.relu(t_469)
+        t_471 = self.n_Conv_49(t_470)
+        t_472 = F.relu(t_471)
+        t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
+        t_473 = self.n_Conv_50(t_472_padded)
+        t_474 = F.relu(t_473)
+        t_475 = self.n_Conv_51(t_474)
+        t_476 = torch.add(t_475, t_470)
+        t_477 = F.relu(t_476)
+        t_478 = self.n_Conv_52(t_477)
+        t_479 = F.relu(t_478)
+        t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
+        t_480 = self.n_Conv_53(t_479_padded)
+        t_481 = F.relu(t_480)
+        t_482 = self.n_Conv_54(t_481)
+        t_483 = torch.add(t_482, t_477)
+        t_484 = F.relu(t_483)
+        t_485 = self.n_Conv_55(t_484)
+        t_486 = F.relu(t_485)
+        t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
+        t_487 = self.n_Conv_56(t_486_padded)
+        t_488 = F.relu(t_487)
+        t_489 = self.n_Conv_57(t_488)
+        t_490 = torch.add(t_489, t_484)
+        t_491 = F.relu(t_490)
+        t_492 = self.n_Conv_58(t_491)
+        t_493 = F.relu(t_492)
+        t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
+        t_494 = self.n_Conv_59(t_493_padded)
+        t_495 = F.relu(t_494)
+        t_496 = self.n_Conv_60(t_495)
+        t_497 = torch.add(t_496, t_491)
+        t_498 = F.relu(t_497)
+        t_499 = self.n_Conv_61(t_498)
+        t_500 = F.relu(t_499)
+        t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
+        t_501 = self.n_Conv_62(t_500_padded)
+        t_502 = F.relu(t_501)
+        t_503 = self.n_Conv_63(t_502)
+        t_504 = torch.add(t_503, t_498)
+        t_505 = F.relu(t_504)
+        t_506 = self.n_Conv_64(t_505)
+        t_507 = F.relu(t_506)
+        t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
+        t_508 = self.n_Conv_65(t_507_padded)
+        t_509 = F.relu(t_508)
+        t_510 = self.n_Conv_66(t_509)
+        t_511 = torch.add(t_510, t_505)
+        t_512 = F.relu(t_511)
+        t_513 = self.n_Conv_67(t_512)
+        t_514 = F.relu(t_513)
+        t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
+        t_515 = self.n_Conv_68(t_514_padded)
+        t_516 = F.relu(t_515)
+        t_517 = self.n_Conv_69(t_516)
+        t_518 = torch.add(t_517, t_512)
+        t_519 = F.relu(t_518)
+        t_520 = self.n_Conv_70(t_519)
+        t_521 = F.relu(t_520)
+        t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
+        t_522 = self.n_Conv_71(t_521_padded)
+        t_523 = F.relu(t_522)
+        t_524 = self.n_Conv_72(t_523)
+        t_525 = torch.add(t_524, t_519)
+        t_526 = F.relu(t_525)
+        t_527 = self.n_Conv_73(t_526)
+        t_528 = F.relu(t_527)
+        t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
+        t_529 = self.n_Conv_74(t_528_padded)
+        t_530 = F.relu(t_529)
+        t_531 = self.n_Conv_75(t_530)
+        t_532 = torch.add(t_531, t_526)
+        t_533 = F.relu(t_532)
+        t_534 = self.n_Conv_76(t_533)
+        t_535 = F.relu(t_534)
+        t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
+        t_536 = self.n_Conv_77(t_535_padded)
+        t_537 = F.relu(t_536)
+        t_538 = self.n_Conv_78(t_537)
+        t_539 = torch.add(t_538, t_533)
+        t_540 = F.relu(t_539)
+        t_541 = self.n_Conv_79(t_540)
+        t_542 = F.relu(t_541)
+        t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
+        t_543 = self.n_Conv_80(t_542_padded)
+        t_544 = F.relu(t_543)
+        t_545 = self.n_Conv_81(t_544)
+        t_546 = torch.add(t_545, t_540)
+        t_547 = F.relu(t_546)
+        t_548 = self.n_Conv_82(t_547)
+        t_549 = F.relu(t_548)
+        t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
+        t_550 = self.n_Conv_83(t_549_padded)
+        t_551 = F.relu(t_550)
+        t_552 = self.n_Conv_84(t_551)
+        t_553 = torch.add(t_552, t_547)
+        t_554 = F.relu(t_553)
+        t_555 = self.n_Conv_85(t_554)
+        t_556 = F.relu(t_555)
+        t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
+        t_557 = self.n_Conv_86(t_556_padded)
+        t_558 = F.relu(t_557)
+        t_559 = self.n_Conv_87(t_558)
+        t_560 = torch.add(t_559, t_554)
+        t_561 = F.relu(t_560)
+        t_562 = self.n_Conv_88(t_561)
+        t_563 = F.relu(t_562)
+        t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
+        t_564 = self.n_Conv_89(t_563_padded)
+        t_565 = F.relu(t_564)
+        t_566 = self.n_Conv_90(t_565)
+        t_567 = torch.add(t_566, t_561)
+        t_568 = F.relu(t_567)
+        t_569 = self.n_Conv_91(t_568)
+        t_570 = F.relu(t_569)
+        t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
+        t_571 = self.n_Conv_92(t_570_padded)
+        t_572 = F.relu(t_571)
+        t_573 = self.n_Conv_93(t_572)
+        t_574 = torch.add(t_573, t_568)
+        t_575 = F.relu(t_574)
+        t_576 = self.n_Conv_94(t_575)
+        t_577 = F.relu(t_576)
+        t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
+        t_578 = self.n_Conv_95(t_577_padded)
+        t_579 = F.relu(t_578)
+        t_580 = self.n_Conv_96(t_579)
+        t_581 = torch.add(t_580, t_575)
+        t_582 = F.relu(t_581)
+        t_583 = self.n_Conv_97(t_582)
+        t_584 = F.relu(t_583)
+        t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
+        t_585 = self.n_Conv_98(t_584_padded)
+        t_586 = F.relu(t_585)
+        t_587 = self.n_Conv_99(t_586)
+        t_588 = self.n_Conv_100(t_582)
+        t_589 = torch.add(t_587, t_588)
+        t_590 = F.relu(t_589)
+        t_591 = self.n_Conv_101(t_590)
+        t_592 = F.relu(t_591)
+        t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
+        t_593 = self.n_Conv_102(t_592_padded)
+        t_594 = F.relu(t_593)
+        t_595 = self.n_Conv_103(t_594)
+        t_596 = torch.add(t_595, t_590)
+        t_597 = F.relu(t_596)
+        t_598 = self.n_Conv_104(t_597)
+        t_599 = F.relu(t_598)
+        t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
+        t_600 = self.n_Conv_105(t_599_padded)
+        t_601 = F.relu(t_600)
+        t_602 = self.n_Conv_106(t_601)
+        t_603 = torch.add(t_602, t_597)
+        t_604 = F.relu(t_603)
+        t_605 = self.n_Conv_107(t_604)
+        t_606 = F.relu(t_605)
+        t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
+        t_607 = self.n_Conv_108(t_606_padded)
+        t_608 = F.relu(t_607)
+        t_609 = self.n_Conv_109(t_608)
+        t_610 = torch.add(t_609, t_604)
+        t_611 = F.relu(t_610)
+        t_612 = self.n_Conv_110(t_611)
+        t_613 = F.relu(t_612)
+        t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
+        t_614 = self.n_Conv_111(t_613_padded)
+        t_615 = F.relu(t_614)
+        t_616 = self.n_Conv_112(t_615)
+        t_617 = torch.add(t_616, t_611)
+        t_618 = F.relu(t_617)
+        t_619 = self.n_Conv_113(t_618)
+        t_620 = F.relu(t_619)
+        t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
+        t_621 = self.n_Conv_114(t_620_padded)
+        t_622 = F.relu(t_621)
+        t_623 = self.n_Conv_115(t_622)
+        t_624 = torch.add(t_623, t_618)
+        t_625 = F.relu(t_624)
+        t_626 = self.n_Conv_116(t_625)
+        t_627 = F.relu(t_626)
+        t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
+        t_628 = self.n_Conv_117(t_627_padded)
+        t_629 = F.relu(t_628)
+        t_630 = self.n_Conv_118(t_629)
+        t_631 = torch.add(t_630, t_625)
+        t_632 = F.relu(t_631)
+        t_633 = self.n_Conv_119(t_632)
+        t_634 = F.relu(t_633)
+        t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
+        t_635 = self.n_Conv_120(t_634_padded)
+        t_636 = F.relu(t_635)
+        t_637 = self.n_Conv_121(t_636)
+        t_638 = torch.add(t_637, t_632)
+        t_639 = F.relu(t_638)
+        t_640 = self.n_Conv_122(t_639)
+        t_641 = F.relu(t_640)
+        t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
+        t_642 = self.n_Conv_123(t_641_padded)
+        t_643 = F.relu(t_642)
+        t_644 = self.n_Conv_124(t_643)
+        t_645 = torch.add(t_644, t_639)
+        t_646 = F.relu(t_645)
+        t_647 = self.n_Conv_125(t_646)
+        t_648 = F.relu(t_647)
+        t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
+        t_649 = self.n_Conv_126(t_648_padded)
+        t_650 = F.relu(t_649)
+        t_651 = self.n_Conv_127(t_650)
+        t_652 = torch.add(t_651, t_646)
+        t_653 = F.relu(t_652)
+        t_654 = self.n_Conv_128(t_653)
+        t_655 = F.relu(t_654)
+        t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
+        t_656 = self.n_Conv_129(t_655_padded)
+        t_657 = F.relu(t_656)
+        t_658 = self.n_Conv_130(t_657)
+        t_659 = torch.add(t_658, t_653)
+        t_660 = F.relu(t_659)
+        t_661 = self.n_Conv_131(t_660)
+        t_662 = F.relu(t_661)
+        t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
+        t_663 = self.n_Conv_132(t_662_padded)
+        t_664 = F.relu(t_663)
+        t_665 = self.n_Conv_133(t_664)
+        t_666 = torch.add(t_665, t_660)
+        t_667 = F.relu(t_666)
+        t_668 = self.n_Conv_134(t_667)
+        t_669 = F.relu(t_668)
+        t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
+        t_670 = self.n_Conv_135(t_669_padded)
+        t_671 = F.relu(t_670)
+        t_672 = self.n_Conv_136(t_671)
+        t_673 = torch.add(t_672, t_667)
+        t_674 = F.relu(t_673)
+        t_675 = self.n_Conv_137(t_674)
+        t_676 = F.relu(t_675)
+        t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
+        t_677 = self.n_Conv_138(t_676_padded)
+        t_678 = F.relu(t_677)
+        t_679 = self.n_Conv_139(t_678)
+        t_680 = torch.add(t_679, t_674)
+        t_681 = F.relu(t_680)
+        t_682 = self.n_Conv_140(t_681)
+        t_683 = F.relu(t_682)
+        t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
+        t_684 = self.n_Conv_141(t_683_padded)
+        t_685 = F.relu(t_684)
+        t_686 = self.n_Conv_142(t_685)
+        t_687 = torch.add(t_686, t_681)
+        t_688 = F.relu(t_687)
+        t_689 = self.n_Conv_143(t_688)
+        t_690 = F.relu(t_689)
+        t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
+        t_691 = self.n_Conv_144(t_690_padded)
+        t_692 = F.relu(t_691)
+        t_693 = self.n_Conv_145(t_692)
+        t_694 = torch.add(t_693, t_688)
+        t_695 = F.relu(t_694)
+        t_696 = self.n_Conv_146(t_695)
+        t_697 = F.relu(t_696)
+        t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
+        t_698 = self.n_Conv_147(t_697_padded)
+        t_699 = F.relu(t_698)
+        t_700 = self.n_Conv_148(t_699)
+        t_701 = torch.add(t_700, t_695)
+        t_702 = F.relu(t_701)
+        t_703 = self.n_Conv_149(t_702)
+        t_704 = F.relu(t_703)
+        t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
+        t_705 = self.n_Conv_150(t_704_padded)
+        t_706 = F.relu(t_705)
+        t_707 = self.n_Conv_151(t_706)
+        t_708 = torch.add(t_707, t_702)
+        t_709 = F.relu(t_708)
+        t_710 = self.n_Conv_152(t_709)
+        t_711 = F.relu(t_710)
+        t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
+        t_712 = self.n_Conv_153(t_711_padded)
+        t_713 = F.relu(t_712)
+        t_714 = self.n_Conv_154(t_713)
+        t_715 = torch.add(t_714, t_709)
+        t_716 = F.relu(t_715)
+        t_717 = self.n_Conv_155(t_716)
+        t_718 = F.relu(t_717)
+        t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
+        t_719 = self.n_Conv_156(t_718_padded)
+        t_720 = F.relu(t_719)
+        t_721 = self.n_Conv_157(t_720)
+        t_722 = torch.add(t_721, t_716)
+        t_723 = F.relu(t_722)
+        t_724 = self.n_Conv_158(t_723)
+        t_725 = self.n_Conv_159(t_723)
+        t_726 = F.relu(t_725)
+        t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
+        t_727 = self.n_Conv_160(t_726_padded)
+        t_728 = F.relu(t_727)
+        t_729 = self.n_Conv_161(t_728)
+        t_730 = torch.add(t_729, t_724)
+        t_731 = F.relu(t_730)
+        t_732 = self.n_Conv_162(t_731)
+        t_733 = F.relu(t_732)
+        t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
+        t_734 = self.n_Conv_163(t_733_padded)
+        t_735 = F.relu(t_734)
+        t_736 = self.n_Conv_164(t_735)
+        t_737 = torch.add(t_736, t_731)
+        t_738 = F.relu(t_737)
+        t_739 = self.n_Conv_165(t_738)
+        t_740 = F.relu(t_739)
+        t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
+        t_741 = self.n_Conv_166(t_740_padded)
+        t_742 = F.relu(t_741)
+        t_743 = self.n_Conv_167(t_742)
+        t_744 = torch.add(t_743, t_738)
+        t_745 = F.relu(t_744)
+        t_746 = self.n_Conv_168(t_745)
+        t_747 = self.n_Conv_169(t_745)
+        t_748 = F.relu(t_747)
+        t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
+        t_749 = self.n_Conv_170(t_748_padded)
+        t_750 = F.relu(t_749)
+        t_751 = self.n_Conv_171(t_750)
+        t_752 = torch.add(t_751, t_746)
+        t_753 = F.relu(t_752)
+        t_754 = self.n_Conv_172(t_753)
+        t_755 = F.relu(t_754)
+        t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
+        t_756 = self.n_Conv_173(t_755_padded)
+        t_757 = F.relu(t_756)
+        t_758 = self.n_Conv_174(t_757)
+        t_759 = torch.add(t_758, t_753)
+        t_760 = F.relu(t_759)
+        t_761 = self.n_Conv_175(t_760)
+        t_762 = F.relu(t_761)
+        t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
+        t_763 = self.n_Conv_176(t_762_padded)
+        t_764 = F.relu(t_763)
+        t_765 = self.n_Conv_177(t_764)
+        t_766 = torch.add(t_765, t_760)
+        t_767 = F.relu(t_766)
+        t_768 = self.n_Conv_178(t_767)
+        t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
+        t_770 = torch.squeeze(t_769, 3)
+        t_770 = torch.squeeze(t_770, 2)
+        t_771 = torch.sigmoid(t_770)
+        return t_771
+
+    def load_state_dict(self, state_dict, **kwargs):
+        self.tags = state_dict.get('tags', [])
+
+        super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
+
diff --git a/modules/devices.py b/modules/devices.py
index eb422583..800510b7 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,62 +1,96 @@
+import sys, os, shlex
 import contextlib
-
 import torch
-
 from modules import errors
+from packaging import version
 
-# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
-has_mps = getattr(torch, 'has_mps', False)
 
-cpu = torch.device("cpu")
+# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
+# check `getattr` and try it for compatibility
+def has_mps() -> bool:
+    if not getattr(torch, 'has_mps', False):
+        return False
+    try:
+        torch.zeros(1).to(torch.device("mps"))
+        return True
+    except Exception:
+        return False
+
+
+def extract_device_id(args, name):
+    for x in range(len(args)):
+        if name in args[x]:
+            return args[x + 1]
+
+    return None
+
+
+def get_cuda_device_string():
+    from modules import shared
+
+    if shared.cmd_opts.device_id is not None:
+        return f"cuda:{shared.cmd_opts.device_id}"
+
+    return "cuda"
 
 
 def get_optimal_device():
     if torch.cuda.is_available():
-        return torch.device("cuda")
+        return torch.device(get_cuda_device_string())
 
-    if has_mps:
+    if has_mps():
         return torch.device("mps")
 
     return cpu
 
 
+def get_device_for(task):
+    from modules import shared
+
+    if task in shared.cmd_opts.use_cpu:
+        return cpu
+
+    return get_optimal_device()
+
+
 def torch_gc():
     if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
+        with torch.cuda.device(get_cuda_device_string()):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
 
 
 def enable_tf32():
     if torch.cuda.is_available():
+
+        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
+        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
+        if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+            torch.backends.cudnn.benchmark = True
+
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
 
 
+
 errors.run(enable_tf32, "Enabling TF32")
 
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+cpu = torch.device("cpu")
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
 
-def randn(seed, shape):
-    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
-    if device.type == 'mps':
-        generator = torch.Generator(device=cpu)
-        generator.manual_seed(seed)
-        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
-        return noise
 
+def randn(seed, shape):
     torch.manual_seed(seed)
+    if device.type == 'mps':
+        return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)
 
 
 def randn_without_seed(shape):
-    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
     if device.type == 'mps':
-        generator = torch.Generator(device=cpu)
-        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
-        return noise
-
+        return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)
 
 
@@ -70,3 +104,37 @@ def autocast(disable=False):
         return contextlib.nullcontext()
 
     return torch.autocast("cuda")
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+orig_tensor_to = torch.Tensor.to
+def tensor_to_fix(self, *args, **kwargs):
+    if self.device.type != 'mps' and \
+       ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
+       (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
+        self = self.contiguous()
+    return orig_tensor_to(self, *args, **kwargs)
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 
+orig_layer_norm = torch.nn.functional.layer_norm
+def layer_norm_fix(*args, **kwargs):
+    if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
+        args = list(args)
+        args[0] = args[0].contiguous()
+    return orig_layer_norm(*args, **kwargs)
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+orig_tensor_numpy = torch.Tensor.numpy
+def numpy_fix(self, *args, **kwargs):
+    if self.requires_grad:
+        self = self.detach()
+    return orig_tensor_numpy(self, *args, **kwargs)
+
+
+# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
+    torch.Tensor.to = tensor_to_fix
+    torch.nn.functional.layer_norm = layer_norm_fix
+    torch.Tensor.numpy = numpy_fix
diff --git a/modules/errors.py b/modules/errors.py
index 372dc51a..a668c014 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -2,9 +2,30 @@ import sys
 import traceback
 
 
+def print_error_explanation(message):
+    lines = message.strip().split("\n")
+    max_len = max([len(x) for x in lines])
+
+    print('=' * max_len, file=sys.stderr)
+    for line in lines:
+        print(line, file=sys.stderr)
+    print('=' * max_len, file=sys.stderr)
+
+
+def display(e: Exception, task):
+    print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+    print(traceback.format_exc(), file=sys.stderr)
+
+    message = str(e)
+    if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+        print_error_explanation("""
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+        """)
+
+
 def run(code, task):
     try:
         code()
     except Exception as e:
-        print(f"{task}: {type(e).__name__}", file=sys.stderr)
-        print(traceback.format_exc(), file=sys.stderr)
+        display(task, e)
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 46ad0da3..9a9c38f1 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -11,62 +11,118 @@ from modules.upscaler import Upscaler, UpscalerData
 from modules.shared import opts
 
 
-def fix_model_layers(crt_model, pretrained_net):
-    # this code is adapted from https://github.com/xinntao/ESRGAN
-    if 'conv_first.weight' in pretrained_net:
-        return pretrained_net
 
-    if 'model.0.weight' not in pretrained_net:
-        is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
-        if is_realesrgan:
-            raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
-        else:
-            raise Exception("The file is not a ESRGAN model.")
+def mod2normal(state_dict):
+    # this code is copied from https://github.com/victorca25/iNNfer
+    if 'conv_first.weight' in state_dict:
+        crt_net = {}
+        items = []
+        for k, v in state_dict.items():
+            items.append(k)
 
-    crt_net = crt_model.state_dict()
-    load_net_clean = {}
-    for k, v in pretrained_net.items():
-        if k.startswith('module.'):
-            load_net_clean[k[7:]] = v
-        else:
-            load_net_clean[k] = v
-    pretrained_net = load_net_clean
+        crt_net['model.0.weight'] = state_dict['conv_first.weight']
+        crt_net['model.0.bias'] = state_dict['conv_first.bias']
 
-    tbd = []
-    for k, v in crt_net.items():
-        tbd.append(k)
+        for k in items.copy():
+            if 'RDB' in k:
+                ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+                if '.weight' in k:
+                    ori_k = ori_k.replace('.weight', '.0.weight')
+                elif '.bias' in k:
+                    ori_k = ori_k.replace('.bias', '.0.bias')
+                crt_net[ori_k] = state_dict[k]
+                items.remove(k)
 
-    # directly copy
-    for k, v in crt_net.items():
-        if k in pretrained_net and pretrained_net[k].size() == v.size():
-            crt_net[k] = pretrained_net[k]
-            tbd.remove(k)
+        crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
+        crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
+        crt_net['model.3.weight'] = state_dict['upconv1.weight']
+        crt_net['model.3.bias'] = state_dict['upconv1.bias']
+        crt_net['model.6.weight'] = state_dict['upconv2.weight']
+        crt_net['model.6.bias'] = state_dict['upconv2.bias']
+        crt_net['model.8.weight'] = state_dict['HRconv.weight']
+        crt_net['model.8.bias'] = state_dict['HRconv.bias']
+        crt_net['model.10.weight'] = state_dict['conv_last.weight']
+        crt_net['model.10.bias'] = state_dict['conv_last.bias']
+        state_dict = crt_net
+    return state_dict
 
-    crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
-    crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
 
-    for k in tbd.copy():
-        if 'RDB' in k:
-            ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
-            if '.weight' in k:
-                ori_k = ori_k.replace('.weight', '.0.weight')
-            elif '.bias' in k:
-                ori_k = ori_k.replace('.bias', '.0.bias')
-            crt_net[k] = pretrained_net[ori_k]
-            tbd.remove(k)
+def resrgan2normal(state_dict, nb=23):
+    # this code is copied from https://github.com/victorca25/iNNfer
+    if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+        re8x = 0
+        crt_net = {}
+        items = []
+        for k, v in state_dict.items():
+            items.append(k)
 
-    crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
-    crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
-    crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
-    crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
-    crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
-    crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
-    crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
-    crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
-    crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
-    crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
+        crt_net['model.0.weight'] = state_dict['conv_first.weight']
+        crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+        for k in items.copy():
+            if "rdb" in k:
+                ori_k = k.replace('body.', 'model.1.sub.')
+                ori_k = ori_k.replace('.rdb', '.RDB')
+                if '.weight' in k:
+                    ori_k = ori_k.replace('.weight', '.0.weight')
+                elif '.bias' in k:
+                    ori_k = ori_k.replace('.bias', '.0.bias')
+                crt_net[ori_k] = state_dict[k]
+                items.remove(k)
+
+        crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
+        crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
+        crt_net['model.3.weight'] = state_dict['conv_up1.weight']
+        crt_net['model.3.bias'] = state_dict['conv_up1.bias']
+        crt_net['model.6.weight'] = state_dict['conv_up2.weight']
+        crt_net['model.6.bias'] = state_dict['conv_up2.bias']
+
+        if 'conv_up3.weight' in state_dict:
+            # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
+            re8x = 3
+            crt_net['model.9.weight'] = state_dict['conv_up3.weight']
+            crt_net['model.9.bias'] = state_dict['conv_up3.bias']
+
+        crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
+        crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
+        crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
+        crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
+
+        state_dict = crt_net
+    return state_dict
+
+
+def infer_params(state_dict):
+    # this code is copied from https://github.com/victorca25/iNNfer
+    scale2x = 0
+    scalemin = 6
+    n_uplayer = 0
+    plus = False
+
+    for block in list(state_dict):
+        parts = block.split(".")
+        n_parts = len(parts)
+        if n_parts == 5 and parts[2] == "sub":
+            nb = int(parts[3])
+        elif n_parts == 3:
+            part_num = int(parts[1])
+            if (part_num > scalemin
+                and parts[0] == "model"
+                and parts[2] == "weight"):
+                scale2x += 1
+            if part_num > n_uplayer:
+                n_uplayer = part_num
+                out_nc = state_dict[block].shape[0]
+        if not plus and "conv1x1" in block:
+            plus = True
+
+    nf = state_dict["model.0.weight"].shape[0]
+    in_nc = state_dict["model.0.weight"].shape[1]
+    out_nc = out_nc
+    scale = 2 ** scale2x
+
+    return in_nc, out_nc, nf, nb, plus, scale
 
-    return crt_net
 
 class UpscalerESRGAN(Upscaler):
     def __init__(self, dirname):
@@ -109,20 +165,39 @@ class UpscalerESRGAN(Upscaler):
             print("Unable to load %s from %s" % (self.model_path, filename))
             return None
 
-        pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
-        crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+        state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
 
-        pretrained_net = fix_model_layers(crt_model, pretrained_net)
-        crt_model.load_state_dict(pretrained_net)
-        crt_model.eval()
+        if "params_ema" in state_dict:
+            state_dict = state_dict["params_ema"]
+        elif "params" in state_dict:
+            state_dict = state_dict["params"]
+            num_conv = 16 if "realesr-animevideov3" in filename else 32
+            model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
+            model.load_state_dict(state_dict)
+            model.eval()
+            return model
 
-        return crt_model
+        if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
+            nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
+            state_dict = resrgan2normal(state_dict, nb)
+        elif "conv_first.weight" in state_dict:
+            state_dict = mod2normal(state_dict)
+        elif "model.0.weight" not in state_dict:
+            raise Exception("The file is not a recognized ESRGAN model.")
+
+        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
+
+        model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
+        model.load_state_dict(state_dict)
+        model.eval()
+
+        return model
 
 
 def upscale_without_tiling(model, img):
     img = np.array(img)
     img = img[:, :, ::-1]
-    img = np.moveaxis(img, 2, 0) / 255
+    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
     img = torch.from_numpy(img).float()
     img = img.unsqueeze(0).to(devices.device_esrgan)
     with torch.no_grad():
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
index e413d36e..bc9ceb2a 100644
--- a/modules/esrgan_model_arch.py
+++ b/modules/esrgan_model_arch.py
@@ -1,80 +1,463 @@
-# this file is taken from https://github.com/xinntao/ESRGAN
+# this file is adapted from https://github.com/victorca25/iNNfer
 
+import math
 import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
-def make_layer(block, n_layers):
-    layers = []
-    for _ in range(n_layers):
-        layers.append(block())
-    return nn.Sequential(*layers)
+####################
+# RRDBNet Generator
+####################
 
+class RRDBNet(nn.Module):
+    def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
+            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
+            finalact=None, gaussian_noise=False, plus=False):
+        super(RRDBNet, self).__init__()
+        n_upscale = int(math.log(upscale, 2))
+        if upscale == 3:
+            n_upscale = 1
 
-class ResidualDenseBlock_5C(nn.Module):
-    def __init__(self, nf=64, gc=32, bias=True):
-        super(ResidualDenseBlock_5C, self).__init__()
-        # gc: growth channel, i.e. intermediate channels
-        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
-        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
-        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
-        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
-        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.resrgan_scale = 0
+        if in_nc % 16 == 0:
+            self.resrgan_scale = 1
+        elif in_nc != 4 and in_nc % 4 == 0:
+            self.resrgan_scale = 2
 
-        # initialization
-        # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+        rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+            norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
+            gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
+        LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
 
-    def forward(self, x):
-        x1 = self.lrelu(self.conv1(x))
-        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
-        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
-        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
-        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-        return x5 * 0.2 + x
+        if upsample_mode == 'upconv':
+            upsample_block = upconv_block
+        elif upsample_mode == 'pixelshuffle':
+            upsample_block = pixelshuffle_block
+        else:
+            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+        if upscale == 3:
+            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+        else:
+            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
+        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
+        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+
+        outact = act(finalact) if finalact else None
+
+        self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+            *upsampler, HR_conv0, HR_conv1, outact)
+
+    def forward(self, x, outm=None):
+        if self.resrgan_scale == 1:
+            feat = pixel_unshuffle(x, scale=4)
+        elif self.resrgan_scale == 2:
+            feat = pixel_unshuffle(x, scale=2)
+        else:
+            feat = x
+
+        return self.model(feat)
 
 
 class RRDB(nn.Module):
-    '''Residual in Residual Dense Block'''
+    """
+    Residual in Residual Dense Block
+    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+    """
 
-    def __init__(self, nf, gc=32):
+    def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+            spectral_norm=False, gaussian_noise=False, plus=False):
         super(RRDB, self).__init__()
-        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+        # This is for backwards compatibility with existing models
+        if nr == 3:
+            self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+                    norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+                    gaussian_noise=gaussian_noise, plus=plus)
+            self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+                    norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+                    gaussian_noise=gaussian_noise, plus=plus)
+            self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+                    norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+                    gaussian_noise=gaussian_noise, plus=plus)
+        else:
+            RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+                                              gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
+            self.RDBs = nn.Sequential(*RDB_list)
 
     def forward(self, x):
-        out = self.RDB1(x)
-        out = self.RDB2(out)
-        out = self.RDB3(out)
+        if hasattr(self, 'RDB1'):
+            out = self.RDB1(x)
+            out = self.RDB2(out)
+            out = self.RDB3(out)
+        else:
+            out = self.RDBs(x)
         return out * 0.2 + x
 
 
-class RRDBNet(nn.Module):
-    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
-        super(RRDBNet, self).__init__()
-        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+class ResidualDenseBlock_5C(nn.Module):
+    """
+    Residual Dense Block
+    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+    Modified options that can be used:
+        - "Partial Convolution based Padding" arXiv:1811.11718
+        - "Spectral normalization" arXiv:1802.05957
+        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. 
+            {Rakotonirina} and A. {Rasoanaivo}
+    """
 
-        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
-        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
-        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        #### upsampling
-        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
-        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+            spectral_norm=False, gaussian_noise=False, plus=False):
+        super(ResidualDenseBlock_5C, self).__init__()
 
-        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.noise = GaussianNoise() if gaussian_noise else None
+        self.conv1x1 = conv1x1(nf, gc) if plus else None
+
+        self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+            spectral_norm=spectral_norm)
+        self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+            spectral_norm=spectral_norm)
+        self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+            spectral_norm=spectral_norm)
+        self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+            spectral_norm=spectral_norm)
+        if mode == 'CNA':
+            last_act = None
+        else:
+            last_act = act_type
+        self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
+            norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
+            spectral_norm=spectral_norm)
 
     def forward(self, x):
-        fea = self.conv_first(x)
-        trunk = self.trunk_conv(self.RRDB_trunk(fea))
-        fea = fea + trunk
+        x1 = self.conv1(x)
+        x2 = self.conv2(torch.cat((x, x1), 1))
+        if self.conv1x1:
+            x2 = x2 + self.conv1x1(x)
+        x3 = self.conv3(torch.cat((x, x1, x2), 1))
+        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+        if self.conv1x1:
+            x4 = x4 + x2
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        if self.noise:
+            return self.noise(x5.mul(0.2) + x)
+        else:
+            return x5 * 0.2 + x
 
-        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
-        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
-        out = self.conv_last(self.lrelu(self.HRconv(fea)))
 
+####################
+# ESRGANplus
+####################
+
+class GaussianNoise(nn.Module):
+    def __init__(self, sigma=0.1, is_relative_detach=False):
+        super().__init__()
+        self.sigma = sigma
+        self.is_relative_detach = is_relative_detach
+        self.noise = torch.tensor(0, dtype=torch.float)
+
+    def forward(self, x):
+        if self.training and self.sigma != 0:
+            self.noise = self.noise.to(x.device)
+            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+            x = x + sampled_noise
+        return x 
+
+def conv1x1(in_planes, out_planes, stride=1):
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# SRVGGNetCompact
+####################
+
+class SRVGGNetCompact(nn.Module):
+    """A compact VGG-style network structure for super-resolution.
+    This class is copied from https://github.com/xinntao/Real-ESRGAN
+    """
+
+    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+        super(SRVGGNetCompact, self).__init__()
+        self.num_in_ch = num_in_ch
+        self.num_out_ch = num_out_ch
+        self.num_feat = num_feat
+        self.num_conv = num_conv
+        self.upscale = upscale
+        self.act_type = act_type
+
+        self.body = nn.ModuleList()
+        # the first conv
+        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+        # the first activation
+        if act_type == 'relu':
+            activation = nn.ReLU(inplace=True)
+        elif act_type == 'prelu':
+            activation = nn.PReLU(num_parameters=num_feat)
+        elif act_type == 'leakyrelu':
+            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.body.append(activation)
+
+        # the body structure
+        for _ in range(num_conv):
+            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+            # activation
+            if act_type == 'relu':
+                activation = nn.ReLU(inplace=True)
+            elif act_type == 'prelu':
+                activation = nn.PReLU(num_parameters=num_feat)
+            elif act_type == 'leakyrelu':
+                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            self.body.append(activation)
+
+        # the last conv
+        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+        # upsample
+        self.upsampler = nn.PixelShuffle(upscale)
+
+    def forward(self, x):
+        out = x
+        for i in range(0, len(self.body)):
+            out = self.body[i](out)
+
+        out = self.upsampler(out)
+        # add the nearest upsampled image, so that the network learns the residual
+        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+        out += base
         return out
+
+
+####################
+# Upsampler
+####################
+
+class Upsample(nn.Module):
+    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+    The input data is assumed to be of the form
+    `minibatch x channels x [optional depth] x [optional height] x width`.
+    """
+
+    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
+        super(Upsample, self).__init__()
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor else None
+        self.mode = mode
+        self.size = size
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+
+    def extra_repr(self):
+        if self.scale_factor is not None:
+            info = 'scale_factor=' + str(self.scale_factor)
+        else:
+            info = 'size=' + str(self.size)
+        info += ', mode=' + self.mode
+        return info
+
+
+def pixel_unshuffle(x, scale):
+    """ Pixel unshuffle.
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+    Returns:
+        Tensor: the pixel unshuffled feature.
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+                        pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
+    """
+    Pixel shuffle layer
+    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+    Neural Network, CVPR17)
+    """
+    conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
+                        pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
+    pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+    n = norm(norm_type, out_nc) if norm_type else None
+    a = act(act_type) if act_type else None
+    return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+                pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
+    """ Upconv layer """
+    upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
+    upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
+                        pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
+    return sequential(upsample, conv)
+
+
+
+
+
+
+
+
+####################
+# Basic blocks
+####################
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+    Args:
+        basic_block (nn.module): nn.module class for basic block. (block)
+        num_basic_block (int): number of blocks. (n_layers)
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+    """ activation helper """
+    act_type = act_type.lower()
+    if act_type == 'relu':
+        layer = nn.ReLU(inplace)
+    elif act_type in ('leakyrelu', 'lrelu'):
+        layer = nn.LeakyReLU(neg_slope, inplace)
+    elif act_type == 'prelu':
+        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+    elif act_type == 'tanh':  # [-1, 1] range output
+        layer = nn.Tanh()
+    elif act_type == 'sigmoid':  # [0, 1] range output
+        layer = nn.Sigmoid()
+    else:
+        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+    return layer
+
+
+class Identity(nn.Module):
+    def __init__(self, *kwargs):
+        super(Identity, self).__init__()
+
+    def forward(self, x, *kwargs):
+        return x
+
+
+def norm(norm_type, nc):
+    """ Return a normalization layer """
+    norm_type = norm_type.lower()
+    if norm_type == 'batch':
+        layer = nn.BatchNorm2d(nc, affine=True)
+    elif norm_type == 'instance':
+        layer = nn.InstanceNorm2d(nc, affine=False)
+    elif norm_type == 'none':
+        def norm_layer(x): return Identity()
+    else:
+        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+    return layer
+
+
+def pad(pad_type, padding):
+    """ padding layer helper """
+    pad_type = pad_type.lower()
+    if padding == 0:
+        return None
+    if pad_type == 'reflect':
+        layer = nn.ReflectionPad2d(padding)
+    elif pad_type == 'replicate':
+        layer = nn.ReplicationPad2d(padding)
+    elif pad_type == 'zero':
+        layer = nn.ZeroPad2d(padding)
+    else:
+        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+    return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+    padding = (kernel_size - 1) // 2
+    return padding
+
+
+class ShortcutBlock(nn.Module):
+    """ Elementwise sum the output of a submodule to its input """
+    def __init__(self, submodule):
+        super(ShortcutBlock, self).__init__()
+        self.sub = submodule
+
+    def forward(self, x):
+        output = x + self.sub(x)
+        return output
+
+    def __repr__(self):
+        return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
+
+
+def sequential(*args):
+    """ Flatten Sequential. It unwraps nn.Sequential. """
+    if len(args) == 1:
+        if isinstance(args[0], OrderedDict):
+            raise NotImplementedError('sequential does not support OrderedDict input.')
+        return args[0]  # No sequential is needed.
+    modules = []
+    for module in args:
+        if isinstance(module, nn.Sequential):
+            for submodule in module.children():
+                modules.append(submodule)
+        elif isinstance(module, nn.Module):
+            modules.append(module)
+    return nn.Sequential(*modules)
+
+
+def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
+               pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
+               spectral_norm=False):
+    """ Conv layer with padding, normalization, activation """
+    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+    padding = get_valid_padding(kernel_size, dilation)
+    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+    padding = padding if pad_type == 'zero' else 0
+
+    if convtype=='PartialConv2D':
+        c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+               dilation=dilation, bias=bias, groups=groups)
+    elif convtype=='DeformConv2D':
+        c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+               dilation=dilation, bias=bias, groups=groups)
+    elif convtype=='Conv3D':
+        c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+                dilation=dilation, bias=bias, groups=groups)
+    else:
+        c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+                dilation=dilation, bias=bias, groups=groups)
+
+    if spectral_norm:
+        c = nn.utils.spectral_norm(c)
+
+    a = act(act_type) if act_type else None
+    if 'CNA' in mode:
+        n = norm(norm_type, out_nc) if norm_type else None
+        return sequential(p, c, n, a)
+    elif mode == 'NAC':
+        if norm_type is None and act_type is not None:
+            a = act(act_type, inplace=False)
+        n = norm(norm_type, in_nc) if norm_type else None
+        return sequential(n, a, p, c)
diff --git a/modules/extensions.py b/modules/extensions.py
new file mode 100644
index 00000000..b522125c
--- /dev/null
+++ b/modules/extensions.py
@@ -0,0 +1,99 @@
+import os
+import sys
+import traceback
+
+import git
+
+from modules import paths, shared
+
+extensions = []
+extensions_dir = os.path.join(paths.script_path, "extensions")
+extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
+
+
+def active():
+    return [x for x in extensions if x.enabled]
+
+
+class Extension:
+    def __init__(self, name, path, enabled=True, is_builtin=False):
+        self.name = name
+        self.path = path
+        self.enabled = enabled
+        self.status = ''
+        self.can_update = False
+        self.is_builtin = is_builtin
+
+        repo = None
+        try:
+            if os.path.exists(os.path.join(path, ".git")):
+                repo = git.Repo(path)
+        except Exception:
+            print(f"Error reading github repository info from {path}:", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+
+        if repo is None or repo.bare:
+            self.remote = None
+        else:
+            try:
+                self.remote = next(repo.remote().urls, None)
+                self.status = 'unknown'
+            except Exception:
+                self.remote = None
+
+    def list_files(self, subdir, extension):
+        from modules import scripts
+
+        dirpath = os.path.join(self.path, subdir)
+        if not os.path.isdir(dirpath):
+            return []
+
+        res = []
+        for filename in sorted(os.listdir(dirpath)):
+            res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
+
+        res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+
+        return res
+
+    def check_updates(self):
+        repo = git.Repo(self.path)
+        for fetch in repo.remote().fetch("--dry-run"):
+            if fetch.flags != fetch.HEAD_UPTODATE:
+                self.can_update = True
+                self.status = "behind"
+                return
+
+        self.can_update = False
+        self.status = "latest"
+
+    def fetch_and_reset_hard(self):
+        repo = git.Repo(self.path)
+        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
+        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
+        repo.git.fetch('--all')
+        repo.git.reset('--hard', 'origin')
+
+
+def list_extensions():
+    extensions.clear()
+
+    if not os.path.isdir(extensions_dir):
+        return
+
+    paths = []
+    for dirname in [extensions_dir, extensions_builtin_dir]:
+        if not os.path.isdir(dirname):
+            return
+
+        for extension_dirname in sorted(os.listdir(dirname)):
+            path = os.path.join(dirname, extension_dirname)
+            if not os.path.isdir(path):
+                continue
+
+            paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
+
+    for dirname, path, is_builtin in paths:
+        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
+        extensions.append(extension)
+
diff --git a/modules/extras.py b/modules/extras.py
index b853fa5b..d665440a 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
 import math
 import os
+import sys
+import traceback
 
 import numpy as np
 from PIL import Image
@@ -7,7 +10,11 @@ from PIL import Image
 import torch
 import tqdm
 
-from modules import processing, shared, images, devices, sd_models
+from typing import Callable, List, OrderedDict, Tuple
+from functools import partial
+from dataclasses import dataclass
+
+from modules import processing, shared, images, devices, sd_models, sd_samplers
 from modules.shared import opts
 import modules.gfpgan_model
 from modules.ui import plaintext_to_html
@@ -15,19 +22,50 @@ import modules.codeformer_model
 import piexif
 import piexif.helper
 import gradio as gr
+import safetensors.torch
+
+class LruCache(OrderedDict):
+    @dataclass(frozen=True)
+    class Key:
+        image_hash: int
+        info_hash: int
+        args_hash: int
+
+    @dataclass
+    class Value:
+        image: Image.Image
+        info: str
+
+    def __init__(self, max_size: int = 5, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._max_size = max_size
+
+    def get(self, key: LruCache.Key) -> LruCache.Value:
+        ret = super().get(key)
+        if ret is not None:
+            self.move_to_end(key)  # Move to end of eviction list
+        return ret
+
+    def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
+        self[key] = value
+        while len(self) > self._max_size:
+            self.popitem(last=False)
 
 
-cached_images = {}
+cached_images: LruCache = LruCache(max_size=5)
 
 
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
     devices.torch_gc()
 
+    shared.state.begin()
+    shared.state.job = 'extras'
+
     imageArr = []
     # Also keep track of original file names
     imageNameArr = []
     outputs = []
-    
+
     if extras_mode == 1:
         #convert file to pillow image
         for img in image_folder:
@@ -39,9 +77,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
 
         if input_dir == '':
             return outputs, "Please select an input directory.", ''
-        image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+        image_list = shared.listfiles(input_dir)
         for img in image_list:
-            image = Image.open(img)
+            try:
+                image = Image.open(img)
+            except Exception:
+                continue
             imageArr.append(image)
             imageNameArr.append(img)
     else:
@@ -53,80 +94,128 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
     else:
         outpath = opts.outdir_samples or opts.outdir_extras_samples
 
-    
-    for image, image_name in zip(imageArr, imageNameArr):
-        if image is None:
-            return outputs, "Please select an input image.", ''
-        existing_pnginfo = image.info or {}
+    # Extra operation definitions
 
-        image = image.convert("RGB")
-        info = ""
+    def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+        shared.state.job = 'extras-gfpgan'
+        restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
+        res = Image.fromarray(restored_img)
 
-        if gfpgan_visibility > 0:
-            restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
-            res = Image.fromarray(restored_img)
+        if gfpgan_visibility < 1.0:
+            res = Image.blend(image, res, gfpgan_visibility)
 
-            if gfpgan_visibility < 1.0:
-                res = Image.blend(image, res, gfpgan_visibility)
+        info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
+        return (res, info)
 
-            info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
-            image = res
+    def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+        shared.state.job = 'extras-codeformer'
+        restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
+        res = Image.fromarray(restored_img)
 
-        if codeformer_visibility > 0:
-            restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
-            res = Image.fromarray(restored_img)
+        if codeformer_visibility < 1.0:
+            res = Image.blend(image, res, codeformer_visibility)
 
-            if codeformer_visibility < 1.0:
-                res = Image.blend(image, res, codeformer_visibility)
+        info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
+        return (res, info)
 
-            info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
-            image = res
+    def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+        shared.state.job = 'extras-upscale'
+        upscaler = shared.sd_upscalers[scaler_index]
+        res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+        if mode == 1 and crop:
+            cropped = Image.new("RGB", (resize_w, resize_h))
+            cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
+            res = cropped
+        return res
 
+    def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+        # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
+        nonlocal upscaling_resize
         if resize_mode == 1:
             upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
             crop_info = " (crop)" if upscaling_crop else ""
             info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
+        return (image, info)
 
-        if upscaling_resize != 1.0:
-            def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
-                small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
-                pixels = tuple(np.array(small).flatten().tolist())
-                key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, 
-                       resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
+    @dataclass
+    class UpscaleParams:
+        upscaler_idx: int
+        blend_alpha: float
 
-                c = cached_images.get(key)
-                if c is None:
-                    upscaler = shared.sd_upscalers[scaler_index]
-                    c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
-                    if mode == 1 and crop:
-                        cropped = Image.new("RGB", (resize_w, resize_h))
-                        cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
-                        c = cropped
-                    cached_images[key] = c
+    def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+        blended_result: Image.Image = None
+        image_hash: str = hash(np.array(image.getdata()).tobytes())
+        for upscaler in params:
+            upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
+                            upscaling_resize_w, upscaling_resize_h, upscaling_crop)
+            cache_key = LruCache.Key(image_hash=image_hash,
+                                     info_hash=hash(info),
+                                     args_hash=hash(upscale_args))
+            cached_entry = cached_images.get(cache_key)
+            if cached_entry is None:
+                res = upscale(image, *upscale_args)
+                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
+                cached_images.put(cache_key, LruCache.Value(image=res, info=info))
+            else:
+                res, info = cached_entry.image, cached_entry.info
 
-                return c
+            if blended_result is None:
+                blended_result = res
+            else:
+                blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
+        return (blended_result, info)
 
-            info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
-            res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
+    # Build a list of operations to run
+    facefix_ops: List[Callable] = []
+    facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
+    facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
 
-            if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
-                res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
-                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
-                res = Image.blend(res, res2, extras_upscaler_2_visibility)
+    upscale_ops: List[Callable] = []
+    upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
 
-            image = res
+    if upscaling_resize != 0:
+        step_params: List[UpscaleParams] = []
+        step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
+        if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+            step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
 
-        while len(cached_images) > 2:
-            del cached_images[next(iter(cached_images.keys()))]
+        upscale_ops.append(partial(run_upscalers_blend, step_params))
 
-        images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
-                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
-                          forced_filename=image_name if opts.use_original_name_batch else None)
+    extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
 
-        if opts.enable_pnginfo:
+    for image, image_name in zip(imageArr, imageNameArr):
+        if image is None:
+            return outputs, "Please select an input image.", ''
+
+        shared.state.textinfo = f'Processing image {image_name}'
+        
+        existing_pnginfo = image.info or {}
+
+        image = image.convert("RGB")
+        info = ""
+        # Run each operation on each image
+        for op in extras_ops:
+            image, info = op(image, info)
+
+        if opts.use_original_name_batch and image_name is not None:
+            basename = os.path.splitext(os.path.basename(image_name))[0]
+        else:
+            basename = ''
+
+        if opts.enable_pnginfo: # append info before save
             image.info = existing_pnginfo
             image.info["extras"] = info
 
+        if save_output:
+            # Add upscaler name as a suffix.
+            suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
+            # Add second upscaler if applicable.
+            if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
+                suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
+
+            images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+                            no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
+
         if extras_mode != 2 or show_extras_results :
             outputs.append(image)
 
@@ -134,30 +223,16 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
 
     return outputs, plaintext_to_html(info), ''
 
+def clear_cache():
+    cached_images.clear()
+
 
 def run_pnginfo(image):
     if image is None:
         return '', '', ''
 
-    items = image.info
-    geninfo = ''
-
-    if "exif" in image.info:
-        exif = piexif.load(image.info["exif"])
-        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
-        try:
-            exif_comment = piexif.helper.UserComment.load(exif_comment)
-        except ValueError:
-            exif_comment = exif_comment.decode('utf8', errors="ignore")
-
-        items['exif comment'] = exif_comment
-        geninfo = exif_comment
-
-        for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
-                      'loop', 'background', 'timestamp', 'duration']:
-            items.pop(field, None)
-
-    geninfo = items.get('parameters', geninfo)
+    geninfo, items = images.read_info_from_image(image)
+    items = {**{'parameters': geninfo}, **items}
 
     info = ''
     for key, text in items.items():
@@ -175,7 +250,10 @@ def run_pnginfo(image):
     return '', geninfo, info
 
 
-def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
+def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+    shared.state.begin()
+    shared.state.job = 'model-merge'
+
     def weighted_sum(theta0, theta1, alpha):
         return ((1 - alpha) * theta0) + (alpha * theta1)
 
@@ -187,23 +265,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
 
     primary_model_info = sd_models.checkpoints_list[primary_model_name]
     secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
-    teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
-
-    print(f"Loading {primary_model_info.filename}...")
-    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
-    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
-
-    print(f"Loading {secondary_model_info.filename}...")
-    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
-    theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
-
-    if teritary_model_info is not None:
-        print(f"Loading {teritary_model_info.filename}...")
-        teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
-        theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
-    else:
-        teritary_model = None
-        theta_2 = None
+    tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
+    result_is_inpainting_model = False
 
     theta_funcs = {
         "Weighted sum": (None, weighted_sum),
@@ -211,9 +274,19 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
     }
     theta_func1, theta_func2 = theta_funcs[interp_method]
 
-    print(f"Merging...")
+    if theta_func1 and not tertiary_model_info:
+        shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
+        shared.state.end()
+        return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+
+    shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
+    print(f"Loading {secondary_model_info.filename}...")
+    theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
 
     if theta_func1:
+        print(f"Loading {tertiary_model_info.filename}...")
+        theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+
         for key in tqdm.tqdm(theta_1.keys()):
             if 'model' in key:
                 if key in theta_2:
@@ -221,12 +294,33 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
                     theta_1[key] = theta_func1(theta_1[key], t2)
                 else:
                     theta_1[key] = torch.zeros_like(theta_1[key])
-    del theta_2, teritary_model
+        del theta_2
+
+    shared.state.textinfo = f"Loading {primary_model_info.filename}..."
+    print(f"Loading {primary_model_info.filename}...")
+    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
+
+    print("Merging...")
 
     for key in tqdm.tqdm(theta_0.keys()):
         if 'model' in key and key in theta_1:
+            a = theta_0[key]
+            b = theta_1[key]
 
-            theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
+            shared.state.textinfo = f'Merging layer {key}'
+            # this enables merging an inpainting model (A) with another one (B);
+            # where normal model would have 4 channels, for latenst space, inpainting model would
+            # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
+            if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
+                if a.shape[1] == 4 and b.shape[1] == 9:
+                    raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
+
+                assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
+
+                theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
+                result_is_inpainting_model = True
+            else:
+                theta_0[key] = theta_func2(a, b, multiplier)
 
             if save_as_half:
                 theta_0[key] = theta_0[key].half()
@@ -237,17 +331,35 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
             theta_0[key] = theta_1[key]
             if save_as_half:
                 theta_0[key] = theta_0[key].half()
+    del theta_1
 
     ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
 
-    filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
-    filename = filename if custom_name == '' else (custom_name + '.ckpt')
+    filename = \
+        primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
+        secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
+        interp_method.replace(" ", "_") + \
+        '-merged.' +  \
+        ("inpainting." if result_is_inpainting_model else "") + \
+        checkpoint_format
+
+    filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
+
     output_modelname = os.path.join(ckpt_dir, filename)
 
+    shared.state.textinfo = f"Saving to {output_modelname}..."
     print(f"Saving to {output_modelname}...")
-    torch.save(primary_model, output_modelname)
+
+    _, extension = os.path.splitext(output_modelname)
+    if extension.lower() == ".safetensors":
+        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+    else:
+        torch.save(theta_0, output_modelname)
 
     sd_models.list_models()
 
-    print(f"Checkpoint saved.")
+    print("Checkpoint saved.")
+    shared.state.textinfo = "Checkpoint saved to " + output_modelname
+    shared.state.end()
+
     return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 98d24406..4baf4d9a 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,14 +1,222 @@
+import base64
+import io
+import math
 import os
 import re
+from pathlib import Path
+
 import gradio as gr
 from modules.shared import script_path
-from modules import shared
+from modules import shared, ui_tempdir
+import tempfile
+from PIL import Image
 
-re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
 re_param = re.compile(re_param_code)
 re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
 re_imagesize = re.compile(r"^(\d+)x(\d+)$")
+re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
 type_of_gr_update = type(gr.update())
+paste_fields = {}
+bind_list = []
+
+
+def reset():
+    paste_fields.clear()
+    bind_list.clear()
+
+
+def quote(text):
+    if ',' not in str(text):
+        return text
+
+    text = str(text)
+    text = text.replace('\\', '\\\\')
+    text = text.replace('"', '\\"')
+    return f'"{text}"'
+
+
+def image_from_url_text(filedata):
+    if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+        filedata = filedata[0]
+
+    if type(filedata) == dict and filedata.get("is_file", False):
+        filename = filedata["name"]
+        is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
+        assert is_in_right_dir, 'trying to open image file outside of allowed directories'
+
+        return Image.open(filename)
+
+    if type(filedata) == list:
+        if len(filedata) == 0:
+            return None
+
+        filedata = filedata[0]
+
+    if filedata.startswith("data:image/png;base64,"):
+        filedata = filedata[len("data:image/png;base64,"):]
+
+    filedata = base64.decodebytes(filedata.encode('utf-8'))
+    image = Image.open(io.BytesIO(filedata))
+    return image
+
+
+def add_paste_fields(tabname, init_img, fields):
+    paste_fields[tabname] = {"init_img": init_img, "fields": fields}
+
+    # backwards compatibility for existing extensions
+    import modules.ui
+    if tabname == 'txt2img':
+        modules.ui.txt2img_paste_fields = fields
+    elif tabname == 'img2img':
+        modules.ui.img2img_paste_fields = fields
+
+
+def integrate_settings_paste_fields(component_dict):
+    from modules import ui
+
+    settings_map = {
+        'sd_hypernetwork': 'Hypernet',
+        'sd_hypernetwork_strength': 'Hypernet strength',
+        'CLIP_stop_at_last_layers': 'Clip skip',
+        'inpainting_mask_weight': 'Conditional mask weight',
+        'sd_model_checkpoint': 'Model hash',
+        'eta_noise_seed_delta': 'ENSD',
+        'initial_noise_multiplier': 'Noise multiplier',
+    }
+    settings_paste_fields = [
+        (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
+        for k, v in settings_map.items()
+    ]
+
+    for tabname, info in paste_fields.items():
+        if info["fields"] is not None:
+            info["fields"] += settings_paste_fields
+
+
+def create_buttons(tabs_list):
+    buttons = {}
+    for tab in tabs_list:
+        buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
+    return buttons
+
+
+#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
+def bind_buttons(buttons, send_image, send_generate_info):
+    bind_list.append([buttons, send_image, send_generate_info])
+
+
+def send_image_and_dimensions(x):
+    if isinstance(x, Image.Image):
+        img = x
+    else:
+        img = image_from_url_text(x)
+
+    if shared.opts.send_size and isinstance(img, Image.Image):
+        w = img.width
+        h = img.height
+    else:
+        w = gr.update()
+        h = gr.update()
+
+    return img, w, h
+
+
+def run_bind():
+    for buttons, source_image_component, send_generate_info in bind_list:
+        for tab in buttons:
+            button = buttons[tab]
+            destination_image_component = paste_fields[tab]["init_img"]
+            fields = paste_fields[tab]["fields"]
+
+            destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+            destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+            if source_image_component and destination_image_component:
+                if isinstance(source_image_component, gr.Gallery):
+                    func = send_image_and_dimensions if destination_width_component else image_from_url_text
+                    jsfunc = "extract_image_from_gallery"
+                else:
+                    func = send_image_and_dimensions if destination_width_component else lambda x: x
+                    jsfunc = None
+
+                button.click(
+                    fn=func,
+                    _js=jsfunc,
+                    inputs=[source_image_component],
+                    outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+                )
+
+            if send_generate_info and fields is not None:
+                if send_generate_info in paste_fields:
+                    paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
+                    button.click(
+                        fn=lambda *x: x,
+                        inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
+                        outputs=[field for field, name in fields if name in paste_field_names],
+                    )
+                else:
+                    connect_paste(button, fields, send_generate_info)
+
+            button.click(
+                fn=None,
+                _js=f"switch_to_{tab}",
+                inputs=None,
+                outputs=None,
+            )
+
+
+def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
+    """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
+
+    Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
+    parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
+
+    If the infotext has no hash, then a hypernet with the same name will be selected instead.
+    """
+    hypernet_name = hypernet_name.lower()
+    if hypernet_hash is not None:
+        # Try to match the hash in the name
+        for hypernet_key in shared.hypernetworks.keys():
+            result = re_hypernet_hash.search(hypernet_key)
+            if result is not None and result[1] == hypernet_hash:
+                return hypernet_key
+    else:
+        # Fall back to a hypernet with the same name
+        for hypernet_key in shared.hypernetworks.keys():
+            if hypernet_key.lower().startswith(hypernet_name):
+                return hypernet_key
+
+    return None
+
+
+def restore_old_hires_fix_params(res):
+    """for infotexts that specify old First pass size parameter, convert it into
+    width, height, and hr scale"""
+
+    firstpass_width = res.get('First pass size-1', None)
+    firstpass_height = res.get('First pass size-2', None)
+
+    if firstpass_width is None or firstpass_height is None:
+        return
+
+    firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
+    width = int(res.get("Size-1", 512))
+    height = int(res.get("Size-2", 512))
+
+    if firstpass_width == 0 or firstpass_height == 0:
+        # old algorithm for auto-calculating first pass size
+        desired_pixel_count = 512 * 512
+        actual_pixel_count = width * height
+        scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+        firstpass_width = math.ceil(scale * width / 64) * 64
+        firstpass_height = math.ceil(scale * height / 64) * 64
+
+    hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
+
+    res['Size-1'] = firstpass_width
+    res['Size-2'] = firstpass_height
+    res['Hires upscale'] = hr_scale
 
 
 def parse_generation_parameters(x: str):
@@ -45,8 +253,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
         else:
             prompt += ("" if prompt == "" else "\n") + line
 
-        res["Prompt"] = prompt
-        res["Negative prompt"] = negative_prompt
+    res["Prompt"] = prompt
+    res["Negative prompt"] = negative_prompt
 
     for k, v in re_param.findall(lastline):
         m = re_imagesize.match(v)
@@ -56,10 +264,24 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
         else:
             res[k] = v
 
+    # Missing CLIP skip means it was set to 1 (the default)
+    if "Clip skip" not in res:
+        res["Clip skip"] = "1"
+
+    if "Hypernet strength" not in res:
+        res["Hypernet strength"] = "1"
+
+    if "Hypernet" in res:
+        hypernet_name = res["Hypernet"]
+        hypernet_hash = res.get("Hypernet hash", None)
+        res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+
+    restore_old_hires_fix_params(res)
+
     return res
 
 
-def connect_paste(button, paste_fields, input_comp, js=None):
+def connect_paste(button, paste_fields, input_comp, jsfunc=None):
     def paste_func(prompt):
         if not prompt and not shared.cmd_opts.hide_ui_dir_config:
             filename = os.path.join(script_path, "params.txt")
@@ -83,7 +305,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
             else:
                 try:
                     valtype = type(output.value)
-                    val = valtype(v)
+
+                    if valtype == bool and v == "False":
+                        val = False
+                    else:
+                        val = valtype(v)
+
                     res.append(gr.update(value=val))
                 except Exception:
                     res.append(gr.update())
@@ -92,7 +319,9 @@ def connect_paste(button, paste_fields, input_comp, js=None):
 
     button.click(
         fn=paste_func,
-        _js=js,
+        _js=jsfunc,
         inputs=[input_comp],
         outputs=[x[0] for x in paste_fields],
     )
+
+
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index a9452dce..1e2dbc32 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -36,7 +36,9 @@ def gfpgann():
     else:
         print("Unable to load gfpgan model!")
         return None
-    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+    if hasattr(facexlib.detection.retinaface, 'device'):
+        facexlib.detection.retinaface.device = devices.device_gfpgan
+    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
     loaded_gfpgan_model = model
 
     return model
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74300122..450fecac 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,40 +1,72 @@
+import csv
 import datetime
 import glob
 import html
 import os
 import sys
 import traceback
-import tqdm
-import csv
+import inspect
 
-import torch
-
-from ldm.util import default
-from modules import devices, shared, processing, sd_models
-import torch
-from torch import einsum
-from einops import rearrange, repeat
 import modules.textual_inversion.dataset
+import torch
+import tqdm
+from einops import rearrange, repeat
+from ldm.util import default
+from modules import devices, processing, sd_models, shared, sd_samplers
 from modules.textual_inversion import textual_inversion
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
+from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
 
+from collections import defaultdict, deque
+from statistics import stdev, mean
+
+
+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
 
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
+    activation_dict = {
+        "linear": torch.nn.Identity,
+        "relu": torch.nn.ReLU,
+        "leakyrelu": torch.nn.LeakyReLU,
+        "elu": torch.nn.ELU,
+        "swish": torch.nn.Hardswish,
+        "tanh": torch.nn.Tanh,
+        "sigmoid": torch.nn.Sigmoid,
+    }
+    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
         super().__init__()
 
-        assert layer_structure is not None, "layer_structure mut not be None"
+        assert layer_structure is not None, "layer_structure must not be None"
         assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
         assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
 
         linears = []
         for i in range(len(layer_structure) - 1):
+
+            # Add a fully-connected layer
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
+            # Add an activation func except last layer
+            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
+                pass
+            elif activation_func in self.activation_dict:
+                linears.append(self.activation_dict[activation_func]())
+            else:
+                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
+            # Add layer normalization
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
+            # Add dropout except last layer
+            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
+                linears.append(torch.nn.Dropout(p=0.3))
+
         self.linear = torch.nn.Sequential(*linears)
 
         if state_dict is not None:
@@ -42,9 +74,25 @@ class HypernetworkModule(torch.nn.Module):
             self.load_state_dict(state_dict)
         else:
             for layer in self.linear:
-                layer.weight.data.normal_(mean=0.0, std=0.01)
-                layer.bias.data.zero_()
-
+                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+                    w, b = layer.weight.data, layer.bias.data
+                    if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
+                        normal_(w, mean=0.0, std=0.01)
+                        normal_(b, mean=0.0, std=0)
+                    elif weight_init == 'XavierUniform':
+                        xavier_uniform_(w)
+                        zeros_(b)
+                    elif weight_init == 'XavierNormal':
+                        xavier_normal_(w)
+                        zeros_(b)
+                    elif weight_init == 'KaimingUniform':
+                        kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+                        zeros_(b)
+                    elif weight_init == 'KaimingNormal':
+                        kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+                        zeros_(b)
+                    else:
+                        raise KeyError(f"Key {weight_init} is not defined as initialization!")
         self.to(devices.device)
 
     def fix_old_state_dict(self, state_dict):
@@ -69,7 +117,8 @@ class HypernetworkModule(torch.nn.Module):
     def trainables(self):
         layer_structure = []
         for layer in self.linear:
-            layer_structure += [layer.weight, layer.bias]
+            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+                layer_structure += [layer.weight, layer.bias]
         return layer_structure
 
 
@@ -81,7 +130,7 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -89,26 +138,48 @@ class Hypernetwork:
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None
         self.layer_structure = layer_structure
+        self.activation_func = activation_func
+        self.weight_init = weight_init
         self.add_layer_norm = add_layer_norm
+        self.use_dropout = use_dropout
+        self.activate_output = activate_output
+        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+        self.optimizer_name = None
+        self.optimizer_state_dict = None
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
-                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
             )
+        self.eval_mode()
 
     def weights(self):
         res = []
+        for k, layers in self.layers.items():
+            for layer in layers:
+                res += layer.parameters()
+        return res
 
+    def train_mode(self):
         for k, layers in self.layers.items():
             for layer in layers:
                 layer.train()
-                res += layer.trainables()
+                for param in layer.parameters():
+                    param.requires_grad = True
 
-        return res
+    def eval_mode(self):
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.eval()
+                for param in layer.parameters():
+                    param.requires_grad = False
 
     def save(self, filename):
         state_dict = {}
+        optimizer_saved_dict = {}
 
         for k, v in self.layers.items():
             state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -116,11 +187,23 @@ class Hypernetwork:
         state_dict['step'] = self.step
         state_dict['name'] = self.name
         state_dict['layer_structure'] = self.layer_structure
+        state_dict['activation_func'] = self.activation_func
         state_dict['is_layer_norm'] = self.add_layer_norm
+        state_dict['weight_initialization'] = self.weight_init
+        state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+        state_dict['activate_output'] = self.activate_output
+        state_dict['last_layer_dropout'] = self.last_layer_dropout
+
+        if self.optimizer_name is not None:
+            optimizer_saved_dict['optimizer_name'] = self.optimizer_name
 
         torch.save(state_dict, filename)
+        if shared.opts.save_optimizer_state and self.optimizer_state_dict:
+            optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+            torch.save(optimizer_saved_dict, filename + '.optim')
 
     def load(self, filename):
         self.filename = filename
@@ -130,13 +213,38 @@ class Hypernetwork:
         state_dict = torch.load(filename, map_location='cpu')
 
         self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+        print(self.layer_structure)
+        self.activation_func = state_dict.get('activation_func', None)
+        print(f"Activation function is {self.activation_func}")
+        self.weight_init = state_dict.get('weight_initialization', 'Normal')
+        print(f"Weight initialization is {self.weight_init}")
         self.add_layer_norm = state_dict.get('is_layer_norm', False)
+        print(f"Layer norm is set to {self.add_layer_norm}")
+        self.use_dropout = state_dict.get('use_dropout', False)
+        print(f"Dropout usage is set to {self.use_dropout}" )
+        self.activate_output = state_dict.get('activate_output', True)
+        print(f"Activate last layer is set to {self.activate_output}")
+        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
+
+        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+        print(f"Optimizer name is {self.optimizer_name}")
+        if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+        else:
+            self.optimizer_state_dict = None
+        if self.optimizer_state_dict:
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)
@@ -147,15 +255,18 @@ class Hypernetwork:
 
 def list_hypernetworks(path):
     res = {}
-    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
         name = os.path.splitext(os.path.basename(filename))[0]
-        res[name] = filename
+        # Prevent a hypothetical "None.pt" from being listed.
+        if name != "None":
+            res[name + f"({sd_models.model_hash(filename)})"] = filename
     return res
 
 
 def load_hypernetwork(filename):
     path = shared.hypernetworks.get(filename, None)
-    if path is not None:
+    # Prevent any file named "None.pt" from being loaded.
+    if path is not None and filename != "None":
         print(f"Loading hypernetwork {filename}")
         try:
             shared.loaded_hypernetwork = Hypernetwork()
@@ -166,7 +277,7 @@ def load_hypernetwork(filename):
             print(traceback.format_exc(), file=sys.stderr)
     else:
         if shared.loaded_hypernetwork is not None:
-            print(f"Unloading hypernetwork")
+            print("Unloading hypernetwork")
 
         shared.loaded_hypernetwork = None
 
@@ -240,16 +351,77 @@ def stack_conds(conds):
     return torch.stack(conds)
 
 
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
-    assert hypernetwork_name, 'hypernetwork not selected'
+def statistics(data):
+    if len(data) < 2:
+        std = 0
+    else:
+        std = stdev(data)
+    total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
+    recent_data = data[-32:]
+    if len(recent_data) < 2:
+        std = 0
+    else:
+        std = stdev(recent_data)
+    recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
+    return total_information, recent_information
+
+
+def report_statistics(loss_info:dict):
+    keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
+    for key in keys:
+        try:
+            print("Loss statistics for file " + key)
+            info, recent = statistics(list(loss_info[key]))
+            print(info)
+            print(recent)
+        except Exception as e:
+            print(e)
+
+
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+    # Remove illegal characters from name.
+    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+
+    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+    if not overwrite_old:
+        assert not os.path.exists(fn), f"file {fn} already exists"
+
+    if type(layer_structure) == str:
+        layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+
+    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+        name=name,
+        enable_sizes=[int(x) for x in enable_sizes],
+        layer_structure=layer_structure,
+        activation_func=activation_func,
+        weight_init=weight_init,
+        add_layer_norm=add_layer_norm,
+        use_dropout=use_dropout,
+    )
+    hypernet.save(fn)
+
+    shared.reload_hypernetworks()
+
+    return fn
+
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
+    from modules import images
+
+    save_hypernetwork_every = save_hypernetwork_every or 0
+    create_image_every = create_image_every or 0
+    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()
     shared.loaded_hypernetwork.load(path)
 
+    shared.state.job = "train-hypernetwork"
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
 
+    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
 
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
@@ -267,126 +439,229 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     else:
         images_dir = None
 
+    hypernetwork = shared.loaded_hypernetwork
+    checkpoint = sd_models.select_checkpoint()
+
+    initial_step = hypernetwork.step or 0
+    if initial_step >= steps:
+        shared.state.textinfo = "Model has already been trained beyond specified max steps"
+        return hypernetwork, filename
+
+    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
+
+    # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
-    with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+
+    pin_memory = shared.opts.pin_memory
+
+    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+    
+    latent_sampling_method = ds.latent_sampling_method
+
+    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
+
+    old_parallel_processing_allowed = shared.parallel_processing_allowed
+
     if unload:
+        shared.parallel_processing_allowed = False
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
-
-    hypernetwork = shared.loaded_hypernetwork
+    
     weights = hypernetwork.weights()
-    for weight in weights:
-        weight.requires_grad = True
+    hypernetwork.train_mode()
 
-    losses = torch.zeros((32,))
+    # Here we use optimizer from saved HN, or we can specify as UI option.
+    if hypernetwork.optimizer_name in optimizer_dict:
+        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+        optimizer_name = hypernetwork.optimizer_name
+    else:
+        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
+        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+        optimizer_name = 'AdamW'
+
+    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
+        try:
+            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+        except RuntimeError as e:
+            print("Cannot resume from saved optimizer!")
+            print(e)
+
+    scaler = torch.cuda.amp.GradScaler()
+    
+    batch_size = ds.batch_size
+    gradient_step = ds.gradient_step
+    # n steps = batch_size * gradient_step * n image processed
+    steps_per_epoch = len(ds) // batch_size // gradient_step
+    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+    loss_step = 0
+    _loss_step = 0 #internal
+    # size = len(ds.indexes)
+    # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+    # losses = torch.zeros((size,))
+    # previous_mean_losses = [0]
+    # previous_mean_loss = 0
+    # print("Mean loss of {} elements".format(size))
+
+    steps_without_grad = 0
 
     last_saved_file = "<none>"
     last_saved_image = "<none>"
+    forced_filename = "<none>"
 
-    ititial_step = hypernetwork.step or 0
-    if ititial_step > steps:
-        return hypernetwork, filename
+    pbar = tqdm.tqdm(total=steps - initial_step)
+    try:
+        for i in range((steps-initial_step) * gradient_step):
+            if scheduler.finished:
+                break
+            if shared.state.interrupted:
+                break
+            for j, batch in enumerate(dl):
+                # works as a drop_last=True for gradient accumulation
+                if j == max_steps_per_epoch:
+                    break
+                scheduler.apply(optimizer, hypernetwork.step)
+                if scheduler.finished:
+                    break
+                if shared.state.interrupted:
+                    break
 
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+                with devices.autocast():
+                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                    if tag_drop_out != 0 or shuffle_tags:
+                        shared.sd_model.cond_stage_model.to(devices.device)
+                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
+                        shared.sd_model.cond_stage_model.to(devices.cpu)
+                    else:
+                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
+                    loss = shared.sd_model(x, c)[0] / gradient_step
+                    del x
+                    del c
 
-    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    for i, entries in pbar:
-        hypernetwork.step = i + ititial_step
+                    _loss_step += loss.item()
+                scaler.scale(loss).backward()
+                # go back until we reach gradient accumulation steps
+                if (j + 1) % gradient_step != 0:
+                    continue
+                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
+                # scaler.unscale_(optimizer)
+                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+                # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
+                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+                scaler.step(optimizer)
+                scaler.update()
+                hypernetwork.step += 1
+                pbar.update()
+                optimizer.zero_grad(set_to_none=True)
+                loss_step = _loss_step
+                _loss_step = 0
 
-        scheduler.apply(optimizer, hypernetwork.step)
-        if scheduler.finished:
-            break
+                steps_done = hypernetwork.step + 1
+                
+                epoch_num = hypernetwork.step // steps_per_epoch
+                epoch_step = hypernetwork.step % steps_per_epoch
 
-        if shared.state.interrupted:
-            break
+                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+                if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+                    # Before saving, change name to match current checkpoint.
+                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
+                    last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+                    hypernetwork.optimizer_name = optimizer_name
+                    if shared.opts.save_optimizer_state:
+                        hypernetwork.optimizer_state_dict = optimizer.state_dict()
+                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
 
-        with torch.autocast("cuda"):
-            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
-            # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
-            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
-            loss = shared.sd_model(x, c)[0]
-            del x
-            del c
+                textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
+                    "loss": f"{loss_step:.7f}",
+                    "learn_rate": scheduler.learn_rate
+                })
 
-            losses[hypernetwork.step % losses.shape[0]] = loss.item()
+                if images_dir is not None and steps_done % create_image_every == 0:
+                    forced_filename = f'{hypernetwork_name}-{steps_done}'
+                    last_saved_image = os.path.join(images_dir, forced_filename)
+                    hypernetwork.eval_mode()
+                    shared.sd_model.cond_stage_model.to(devices.device)
+                    shared.sd_model.first_stage_model.to(devices.device)
 
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-        mean_loss = losses.mean()
-        if torch.isnan(mean_loss):
-            raise RuntimeError("Loss diverged.")
-        pbar.set_description(f"loss: {mean_loss:.7f}")
+                    p = processing.StableDiffusionProcessingTxt2Img(
+                        sd_model=shared.sd_model,
+                        do_not_save_grid=True,
+                        do_not_save_samples=True,
+                    )
 
-        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
-            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
-            hypernetwork.save(last_saved_file)
+                    if preview_from_txt2img:
+                        p.prompt = preview_prompt
+                        p.negative_prompt = preview_negative_prompt
+                        p.steps = preview_steps
+                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+                        p.cfg_scale = preview_cfg_scale
+                        p.seed = preview_seed
+                        p.width = preview_width
+                        p.height = preview_height
+                    else:
+                        p.prompt = batch.cond_text[0]
+                        p.steps = 20
+                        p.width = training_width
+                        p.height = training_height
 
-        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
-            "loss": f"{mean_loss:.7f}",
-            "learn_rate": scheduler.learn_rate
-        })
+                    preview_text = p.prompt
 
-        if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
-            last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+                    processed = processing.process_images(p)
+                    image = processed.images[0] if len(processed.images) > 0 else None
 
-            optimizer.zero_grad()
-            shared.sd_model.cond_stage_model.to(devices.device)
-            shared.sd_model.first_stage_model.to(devices.device)
+                    if unload:
+                        shared.sd_model.cond_stage_model.to(devices.cpu)
+                        shared.sd_model.first_stage_model.to(devices.cpu)
+                    hypernetwork.train_mode()
+                    if image is not None:
+                        shared.state.current_image = image
+                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                        last_saved_image += f", prompt: {preview_text}"
 
-            p = processing.StableDiffusionProcessingTxt2Img(
-                sd_model=shared.sd_model,
-                do_not_save_grid=True,
-                do_not_save_samples=True,
-            )
+                shared.state.job_no = hypernetwork.step
 
-            if preview_from_txt2img:
-                p.prompt = preview_prompt
-                p.negative_prompt = preview_negative_prompt
-                p.steps = preview_steps
-                p.sampler_index = preview_sampler_index
-                p.cfg_scale = preview_cfg_scale
-                p.seed = preview_seed
-                p.width = preview_width
-                p.height = preview_height
-            else:
-                p.prompt = entries[0].cond_text
-                p.steps = 20
-
-            preview_text = p.prompt
-
-            processed = processing.process_images(p)
-            image = processed.images[0] if len(processed.images)>0 else None
-
-            if unload:
-                shared.sd_model.cond_stage_model.to(devices.cpu)
-                shared.sd_model.first_stage_model.to(devices.cpu)
-
-            if image is not None:
-                shared.state.current_image = image
-                image.save(last_saved_image)
-                last_saved_image += f", prompt: {preview_text}"
-
-        shared.state.job_no = hypernetwork.step
-
-        shared.state.textinfo = f"""
+                shared.state.textinfo = f"""
 <p>
-Loss: {mean_loss:.7f}<br/>
-Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(entries[0].cond_text)}<br/>
-Last saved embedding: {html.escape(last_saved_file)}<br/>
+Loss: {loss_step:.7f}<br/>
+Step: {steps_done}<br/>
+Last prompt: {html.escape(batch.cond_text[0])}<br/>
+Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
+    except Exception:
+        print(traceback.format_exc(), file=sys.stderr)
+    finally:
+        pbar.leave = False
+        pbar.close()
+        hypernetwork.eval_mode()
+        #report_statistics(loss_dict)
 
-    checkpoint = sd_models.select_checkpoint()
+    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+    hypernetwork.optimizer_name = optimizer_name
+    if shared.opts.save_optimizer_state:
+        hypernetwork.optimizer_state_dict = optimizer.state_dict()
+    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
 
-    hypernetwork.sd_checkpoint = checkpoint.hash
-    hypernetwork.sd_checkpoint_name = checkpoint.model_name
-    hypernetwork.save(filename)
+    del optimizer
+    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
+    shared.sd_model.cond_stage_model.to(devices.device)
+    shared.sd_model.first_stage_model.to(devices.device)
+    shared.parallel_processing_allowed = old_parallel_processing_allowed
 
     return hypernetwork, filename
 
-
+def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
+    old_hypernetwork_name = hypernetwork.name
+    old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
+    old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
+    try:
+        hypernetwork.sd_checkpoint = checkpoint.hash
+        hypernetwork.sd_checkpoint_name = checkpoint.model_name
+        hypernetwork.name = hypernetwork_name
+        hypernetwork.save(filename)
+    except:
+        hypernetwork.sd_checkpoint = old_sd_checkpoint
+        hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
+        hypernetwork.name = old_hypernetwork_name
+        raise
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 08f75f15..e7f9e593 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -3,31 +3,16 @@ import os
 import re
 
 import gradio as gr
+import modules.hypernetworks.hypernetwork
+from modules import devices, sd_hijack, shared
 
-import modules.textual_inversion.textual_inversion
-import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared, devices
-from modules.hypernetworks import hypernetwork
+not_available = ["hardswish", "multiheadattention"]
+keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
 
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
 
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
-    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
-    assert not os.path.exists(fn), f"file {fn} already exists"
-
-    if type(layer_structure) == str:
-        layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
-
-    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
-        name=name,
-        enable_sizes=[int(x) for x in enable_sizes],
-        layer_structure=layer_structure,
-        add_layer_norm=add_layer_norm,
-    )
-    hypernet.save(fn)
-
-    shared.reload_hypernetworks()
-
-    return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+    return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
 
 
 def train_hypernetwork(*args):
diff --git a/modules/images.py b/modules/images.py
index b9589563..c3a5fc8b 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,4 +1,8 @@
 import datetime
+import sys
+import traceback
+
+import pytz
 import io
 import math
 import os
@@ -11,8 +15,9 @@ import piexif.helper
 from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
 from fonts.ttf import Roboto
 import string
+import json
 
-from modules import sd_samplers, shared
+from modules import sd_samplers, shared, script_callbacks
 from modules.shared import opts, cmd_opts
 
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -34,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
 
     cols = math.ceil(len(imgs) / rows)
 
-    w, h = imgs[0].size
-    grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
+    script_callbacks.image_grid_callback(params)
 
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
+    w, h = imgs[0].size
+    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
+
+    for i, img in enumerate(params.imgs):
+        grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
 
     return grid
 
@@ -131,8 +139,19 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
                 lines.append(word)
         return lines
 
-    def draw_texts(drawing, draw_x, draw_y, lines):
+    def get_font(fontsize):
+        try:
+            return ImageFont.truetype(opts.font or Roboto, fontsize)
+        except Exception:
+            return ImageFont.truetype(Roboto, fontsize)
+
+    def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
         for i, line in enumerate(lines):
+            fnt = initial_fnt
+            fontsize = initial_fontsize
+            while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
+                fontsize -= 1
+                fnt = get_font(fontsize)
             drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
 
             if not line.is_active:
@@ -143,10 +162,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
     fontsize = (width + height) // 25
     line_spacing = fontsize // 2
 
-    try:
-        fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
-    except Exception:
-        fnt = ImageFont.truetype(Roboto, fontsize)
+    fnt = get_font(fontsize)
 
     color_active = (0, 0, 0)
     color_inactive = (153, 153, 153)
@@ -173,6 +189,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
         for line in texts:
             bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
             line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
+            line.allowed_width = allowed_width
 
     hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
     ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
@@ -189,13 +206,13 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
         x = pad_left + width * col + width / 2
         y = pad_top / 2 - hor_text_heights[col] / 2
 
-        draw_texts(d, x, y, hor_texts[col])
+        draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
 
     for row in range(rows):
         x = pad_left / 2
         y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
 
-        draw_texts(d, x, y, ver_texts[row])
+        draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
 
     return result
 
@@ -213,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
     return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
 
 
-def resize_image(resize_mode, im, width, height):
+def resize_image(resize_mode, im, width, height, upscaler_name=None):
+    """
+    Resizes an image with the specified resize_mode, width, and height.
+
+    Args:
+        resize_mode: The mode to use when resizing the image.
+            0: Resize the image to the specified width and height.
+            1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+        im: The image to resize.
+        width: The width to resize the image to.
+        height: The height to resize the image to.
+        upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
+    """
+
+    upscaler_name = upscaler_name or opts.upscaler_for_img2img
+
     def resize(im, w, h):
-        if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
+        if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
             return im.resize((w, h), resample=LANCZOS)
 
         scale = max(w / im.width, h / im.height)
 
         if scale > 1.0:
-            upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
-            assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
+            upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
+            assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
 
             upscaler = upscalers[0]
             im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
@@ -273,10 +306,15 @@ invalid_filename_chars = '<>:"/\\|?*\n'
 invalid_filename_prefix = ' '
 invalid_filename_postfix = ' .'
 re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
+re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
+re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
 max_filename_part_length = 128
 
 
 def sanitize_filename_part(text, replace_spaces=True):
+    if text is None:
+        return None
+
     if replace_spaces:
         text = text.replace(' ', '_')
 
@@ -286,49 +324,105 @@ def sanitize_filename_part(text, replace_spaces=True):
     return text
 
 
-def apply_filename_pattern(x, p, seed, prompt):
-    max_prompt_words = opts.directories_max_prompt_words
+class FilenameGenerator:
+    replacements = {
+        'seed': lambda self: self.seed if self.seed is not None else '',
+        'steps': lambda self:  self.p and self.p.steps,
+        'cfg': lambda self: self.p and self.p.cfg_scale,
+        'width': lambda self: self.image.width,
+        'height': lambda self: self.image.height,
+        'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
+        'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
+        'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
+        'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
+        'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
+        'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
+        'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
+        'prompt': lambda self: sanitize_filename_part(self.prompt),
+        'prompt_no_styles': lambda self: self.prompt_no_style(),
+        'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
+        'prompt_words': lambda self: self.prompt_words(),
+    }
+    default_time_format = '%Y%m%d%H%M%S'
 
-    if seed is not None:
-        x = x.replace("[seed]", str(seed))
+    def __init__(self, p, seed, prompt, image):
+        self.p = p
+        self.seed = seed
+        self.prompt = prompt
+        self.image = image
 
-    if p is not None:
-        x = x.replace("[steps]", str(p.steps))
-        x = x.replace("[cfg]", str(p.cfg_scale))
-        x = x.replace("[width]", str(p.width))
-        x = x.replace("[height]", str(p.height))
-        x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
-        x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
+    def prompt_no_style(self):
+        if self.p is None or self.prompt is None:
+            return None
 
-    x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
-    x = x.replace("[date]", datetime.date.today().isoformat())
-    x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
-    x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
+        prompt_no_style = self.prompt
+        for style in shared.prompt_styles.get_style_prompts(self.p.styles):
+            if len(style) > 0:
+                for part in style.split("{prompt}"):
+                    prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
 
-    # Apply [prompt] at last. Because it may contain any replacement word.^M
-    if prompt is not None:
-        x = x.replace("[prompt]", sanitize_filename_part(prompt))
-        if "[prompt_no_styles]" in x:
-            prompt_no_style = prompt
-            for style in shared.prompt_styles.get_style_prompts(p.styles):
-                if len(style) > 0:
-                    style_parts = [y for y in style.split("{prompt}")]
-                    for part in style_parts:
-                        prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
-            prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
-            x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
+                prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
 
-        x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
-        if "[prompt_words]" in x:
-            words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
-            if len(words) == 0:
-                words = ["empty"]
-            x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
+        return sanitize_filename_part(prompt_no_style, replace_spaces=False)
 
-    if cmd_opts.hide_ui_dir_config:
-        x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
+    def prompt_words(self):
+        words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
+        if len(words) == 0:
+            words = ["empty"]
+        return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
 
-    return x
+    def datetime(self, *args):
+        time_datetime = datetime.datetime.now()
+
+        time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
+        try:
+            time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
+        except pytz.exceptions.UnknownTimeZoneError as _:
+            time_zone = None
+
+        time_zone_time = time_datetime.astimezone(time_zone)
+        try:
+            formatted_time = time_zone_time.strftime(time_format)
+        except (ValueError, TypeError) as _:
+            formatted_time = time_zone_time.strftime(self.default_time_format)
+
+        return sanitize_filename_part(formatted_time, replace_spaces=False)
+
+    def apply(self, x):
+        res = ''
+
+        for m in re_pattern.finditer(x):
+            text, pattern = m.groups()
+            res += text
+
+            if pattern is None:
+                continue
+
+            pattern_args = []
+            while True:
+                m = re_pattern_arg.match(pattern)
+                if m is None:
+                    break
+
+                pattern, arg = m.groups()
+                pattern_args.insert(0, arg)
+
+            fun = self.replacements.get(pattern.lower())
+            if fun is not None:
+                try:
+                    replacement = fun(self, *pattern_args)
+                except Exception:
+                    replacement = None
+                    print(f"Error adding [{pattern}] to filename", file=sys.stderr)
+                    print(traceback.format_exc(), file=sys.stderr)
+
+                if replacement is not None:
+                    res += str(replacement)
+                    continue
+
+            res += f'[{pattern}]'
+
+        return res
 
 
 def get_next_sequence_number(path, basename):
@@ -354,7 +448,7 @@ def get_next_sequence_number(path, basename):
 
 
 def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
-    '''Save an image.
+    """Save an image.
 
     Args:
         image (`PIL.Image`):
@@ -363,7 +457,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
             The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
         basename (`str`):
             The base filename which will be applied to `filename pattern`.
-        seed, prompt, short_filename, 
+        seed, prompt, short_filename,
         extension (`str`):
             Image file extension, default is `png`.
         pngsectionname (`str`):
@@ -385,66 +479,94 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
             The full path of the saved imaged.
         txt_fullfn (`str` or None):
             If a text file is saved for this image, this will be its full path. Otherwise None.
-    '''
-    if short_filename or prompt is None or seed is None:
-        file_decoration = ""
-    elif opts.save_to_dirs:
-        file_decoration = opts.samples_filename_pattern or "[seed]"
-    else:
-        file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
-
-    if file_decoration != "":
-        file_decoration = "-" + file_decoration.lower()
-
-    file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
-
-    if extension == 'png' and opts.enable_pnginfo and info is not None:
-        pnginfo = PngImagePlugin.PngInfo()
-
-        if existing_info is not None:
-            for k, v in existing_info.items():
-                pnginfo.add_text(k, str(v))
-
-        pnginfo.add_text(pnginfo_section_name, info)
-    else:
-        pnginfo = None
+    """
+    namegen = FilenameGenerator(p, seed, prompt, image)
 
     if save_to_dirs is None:
         save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
 
     if save_to_dirs:
-        dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
+        dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
         path = os.path.join(path, dirname)
 
     os.makedirs(path, exist_ok=True)
 
     if forced_filename is None:
-        basecount = get_next_sequence_number(path, basename)
-        fullfn = "a.png"
-        fullfn_without_extension = "a"
-        for i in range(500):
-            fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
-            fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
-            fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
-            if not os.path.exists(fullfn):
-                break
+        if short_filename or seed is None:
+            file_decoration = ""
+        elif opts.save_to_dirs:
+            file_decoration = opts.samples_filename_pattern or "[seed]"
+        else:
+            file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
+
+        add_number = opts.save_images_add_number or file_decoration == ''
+
+        if file_decoration != "" and add_number:
+            file_decoration = "-" + file_decoration
+
+        file_decoration = namegen.apply(file_decoration) + suffix
+
+        if add_number:
+            basecount = get_next_sequence_number(path, basename)
+            fullfn = None
+            for i in range(500):
+                fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
+                fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
+                if not os.path.exists(fullfn):
+                    break
+        else:
+            fullfn = os.path.join(path, f"{file_decoration}.{extension}")
     else:
         fullfn = os.path.join(path, f"{forced_filename}.{extension}")
-        fullfn_without_extension = os.path.join(path, forced_filename)
 
-    def exif_bytes():
-        return piexif.dump({
-            "Exif": {
-                piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
-            },
-        })
+    pnginfo = existing_info or {}
+    if info is not None:
+        pnginfo[pnginfo_section_name] = info
 
-    if extension.lower() in ("jpg", "jpeg", "webp"):
-        image.save(fullfn, quality=opts.jpeg_quality)
-        if opts.enable_pnginfo and info is not None:
-            piexif.insert(exif_bytes(), fullfn)
-    else:
-        image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
+    params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
+    script_callbacks.before_image_saved_callback(params)
+
+    image = params.image
+    fullfn = params.filename
+    info = params.pnginfo.get(pnginfo_section_name, None)
+
+    def _atomically_save_image(image_to_save, filename_without_extension, extension):
+        # save image with .tmp extension to avoid race condition when another process detects new image in the directory
+        temp_file_path = filename_without_extension + ".tmp"
+        image_format = Image.registered_extensions()[extension]
+
+        if extension.lower() == '.png':
+            pnginfo_data = PngImagePlugin.PngInfo()
+            if opts.enable_pnginfo:
+                for k, v in params.pnginfo.items():
+                    pnginfo_data.add_text(k, str(v))
+
+            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
+
+        elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+            if image_to_save.mode == 'RGBA':
+                image_to_save = image_to_save.convert("RGB")
+
+            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+
+            if opts.enable_pnginfo and info is not None:
+                exif_bytes = piexif.dump({
+                    "Exif": {
+                        piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
+                    },
+                })
+
+                piexif.insert(exif_bytes, temp_file_path)
+        else:
+            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+
+        # atomically rename the file with correct extension
+        os.replace(temp_file_path, filename_without_extension + extension)
+
+    fullfn_without_extension, extension = os.path.splitext(params.filename)
+    _atomically_save_image(image, fullfn_without_extension, extension)
+
+    image.already_saved_as = fullfn
 
     target_side_length = 4000
     oversize = image.width > target_side_length or image.height > target_side_length
@@ -456,9 +578,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         elif oversize:
             image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
 
-        image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
-        if opts.enable_pnginfo and info is not None:
-            piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
+        _atomically_save_image(image, fullfn_without_extension, ".jpg")
 
     if opts.save_txt and info is not None:
         txt_fullfn = f"{fullfn_without_extension}.txt"
@@ -467,13 +587,50 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
     else:
         txt_fullfn = None
 
+    script_callbacks.image_saved_callback(params)
+
     return fullfn, txt_fullfn
 
 
+def read_info_from_image(image):
+    items = image.info or {}
+
+    geninfo = items.pop('parameters', None)
+
+    if "exif" in items:
+        exif = piexif.load(items["exif"])
+        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
+        try:
+            exif_comment = piexif.helper.UserComment.load(exif_comment)
+        except ValueError:
+            exif_comment = exif_comment.decode('utf8', errors="ignore")
+
+        items['exif comment'] = exif_comment
+        geninfo = exif_comment
+
+        for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+                      'loop', 'background', 'timestamp', 'duration']:
+            items.pop(field, None)
+
+    if items.get("Software", None) == "NovelAI":
+        try:
+            json_info = json.loads(items["Comment"])
+            sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
+
+            geninfo = f"""{items["Description"]}
+Negative prompt: {json_info["uc"]}
+Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
+        except Exception:
+            print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+
+    return geninfo, items
+
+
 def image_data(data):
     try:
         image = Image.open(io.BytesIO(data))
-        textinfo = image.text["parameters"]
+        textinfo, _ = read_info_from_image(image)
         return textinfo, None
     except Exception:
         pass
@@ -487,3 +644,14 @@ def image_data(data):
         pass
 
     return '', None
+
+
+def flatten(img, bgcolor):
+    """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
+
+    if img.mode == "RGBA":
+        background = Image.new('RGBA', img.size, bgcolor)
+        background.paste(img, mask=img)
+        img = background
+
+    return img.convert('RGB')
diff --git a/modules/images_history.py b/modules/images_history.py
deleted file mode 100644
index 46b23e56..00000000
--- a/modules/images_history.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import os
-import shutil
-import sys
-
-def traverse_all_files(output_dir, image_list, curr_dir=None):
-    curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
-    try:
-        f_list = os.listdir(curr_path)
-    except:
-        if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
-            image_list.append(curr_dir)
-        return image_list
-    for file in f_list:
-        file = file if curr_dir is None else os.path.join(curr_dir, file)
-        file_path = os.path.join(curr_path, file)
-        if file[-4:] == ".txt":
-            pass
-        elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
-            image_list.append(file)
-        else:
-            image_list = traverse_all_files(output_dir, image_list, file)
-    return image_list
-
-
-def get_recent_images(dir_name, page_index, step, image_index, tabname):
-    page_index = int(page_index)
-    image_list = []
-    if not os.path.exists(dir_name):
-        pass
-    elif os.path.isdir(dir_name):
-        image_list = traverse_all_files(dir_name, image_list)
-        image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
-    else:
-        print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
-    num = 48 if tabname != "extras" else 12
-    max_page_index = len(image_list) // num + 1
-    page_index = max_page_index if page_index == -1 else page_index + step
-    page_index = 1 if page_index < 1 else page_index
-    page_index = max_page_index if page_index > max_page_index else page_index
-    idx_frm = (page_index - 1) * num
-    image_list = image_list[idx_frm:idx_frm + num]
-    image_index = int(image_index)
-    if image_index < 0 or image_index > len(image_list) - 1:
-        current_file = None
-        hidden = None
-    else:
-        current_file = image_list[int(image_index)]
-        hidden = os.path.join(dir_name, current_file)
-    return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
-
-
-def first_page_click(dir_name, page_index, image_index, tabname):
-    return get_recent_images(dir_name, 1, 0, image_index, tabname)
-
-
-def end_page_click(dir_name, page_index, image_index, tabname):
-    return get_recent_images(dir_name, -1, 0, image_index, tabname)
-
-
-def prev_page_click(dir_name, page_index, image_index, tabname):
-    return get_recent_images(dir_name, page_index, -1, image_index, tabname)
-
-
-def next_page_click(dir_name, page_index, image_index, tabname):
-    return get_recent_images(dir_name, page_index, 1, image_index, tabname)
-
-
-def page_index_change(dir_name, page_index, image_index, tabname):
-    return get_recent_images(dir_name, page_index, 0, image_index, tabname)
-
-
-def show_image_info(num, image_path, filenames):
-    # print(f"select image {num}")
-    file = filenames[int(num)]
-    return file, num, os.path.join(image_path, file)
-
-
-def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
-    if name == "":
-        return filenames, delete_num
-    else:
-        delete_num = int(delete_num)
-        index = list(filenames).index(name)
-        i = 0
-        new_file_list = []
-        for name in filenames:
-            if i >= index and i < index + delete_num:
-                path = os.path.join(dir_name, name)
-                if os.path.exists(path):
-                    print(f"Delete file {path}")
-                    os.remove(path)
-                    txt_file = os.path.splitext(path)[0] + ".txt"
-                    if os.path.exists(txt_file):
-                        os.remove(txt_file)
-                else:
-                    print(f"Not exists file {path}")
-            else:
-                new_file_list.append(name)
-            i += 1
-    return new_file_list, 1
-
-
-def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
-    if opts.outdir_samples != "":
-        dir_name = opts.outdir_samples
-    elif tabname == "txt2img":
-        dir_name = opts.outdir_txt2img_samples
-    elif tabname == "img2img":
-        dir_name = opts.outdir_img2img_samples
-    elif tabname == "extras":
-        dir_name = opts.outdir_extras_samples
-    else:
-        return
-    with gr.Row():
-        renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
-        first_page = gr.Button('First Page')
-        prev_page = gr.Button('Prev Page')
-        page_index = gr.Number(value=1, label="Page Index")
-        next_page = gr.Button('Next Page')
-        end_page = gr.Button('End Page')
-    with gr.Row(elem_id=tabname + "_images_history"):
-        with gr.Row():
-            with gr.Column(scale=2):
-                history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
-                with gr.Row():
-                    delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
-                    delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
-            with gr.Column():
-                with gr.Row():
-                    pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
-                    pnginfo_send_to_img2img = gr.Button('Send to img2img')
-                with gr.Row():
-                    with gr.Column():
-                        img_file_info = gr.Textbox(label="Generate Info", interactive=False)
-                        img_file_name = gr.Textbox(label="File Name", interactive=False)
-                with gr.Row():
-                    # hiden items
-
-                    img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
-                    tabname_box = gr.Textbox(tabname, visible=False)
-                    image_index = gr.Textbox(value=-1, visible=False)
-                    set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
-                    filenames = gr.State()
-                    hidden = gr.Image(type="pil", visible=False)
-                    info1 = gr.Textbox(visible=False)
-                    info2 = gr.Textbox(visible=False)
-
-    # turn pages
-    gallery_inputs = [img_path, page_index, image_index, tabname_box]
-    gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
-
-    first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
-    next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
-    prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
-    end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
-    page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
-    renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
-    # page_index.change(page_index_change, inputs=[tabname_box, img_path,  page_index], outputs=[history_gallery, page_index])
-
-    # other funcitons
-    set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
-    img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
-    delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
-    hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
-
-    # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
-    switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
-    switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
-
-
-def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
-    with gr.Blocks(analytics_enabled=False) as images_history:
-        with gr.Tabs() as tabs:
-            with gr.Tab("txt2img history"):
-                with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
-                    show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
-            with gr.Tab("img2img history"):
-                with gr.Blocks(analytics_enabled=False) as images_history_img2img:
-                    show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
-            with gr.Tab("extras history"):
-                with gr.Blocks(analytics_enabled=False) as images_history_img2img:
-                    show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
-    return images_history
diff --git a/modules/img2img.py b/modules/img2img.py
index 24126774..ca58b5d8 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -4,9 +4,9 @@ import sys
 import traceback
 
 import numpy as np
-from PIL import Image, ImageOps, ImageChops
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
 
-from modules import devices
+from modules import devices, sd_samplers
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
 import modules.shared as shared
@@ -19,7 +19,7 @@ import modules.scripts
 def process_batch(p, input_dir, output_dir, args):
     processing.fix_seed(p)
 
-    images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+    images = shared.listfiles(input_dir)
 
     print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
 
@@ -39,6 +39,8 @@ def process_batch(p, input_dir, output_dir, args):
             break
 
         img = Image.open(image)
+        # Use the EXIF orientation of photos taken by smartphones.
+        img = ImageOps.exif_transpose(img)
         p.init_images = [img] * p.batch_size
 
         proc = modules.scripts.scripts_img2img.run(p, *args)
@@ -53,27 +55,48 @@ def process_batch(p, input_dir, output_dir, args):
                 filename = f"{left}-{n}{right}"
 
             if not save_normally:
+                os.makedirs(output_dir, exist_ok=True)
                 processed_image.save(os.path.join(output_dir, filename))
 
 
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
     is_inpaint = mode == 1
     is_batch = mode == 2
 
     if is_inpaint:
+        # Drawn mask
         if mask_mode == 0:
-            image = init_img_with_mask['image']
-            mask = init_img_with_mask['mask']
-            alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
-            mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
-            image = image.convert('RGB')
+            is_mask_sketch = isinstance(init_img_with_mask, dict)
+            is_mask_paint = not is_mask_sketch
+            if is_mask_sketch:
+                # Sketch: mask iff. not transparent
+                image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
+                alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
+                mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+            else:
+                # Color-sketch: mask iff. painted over
+                image = init_img_with_mask
+                orig = init_img_with_mask_orig or init_img_with_mask
+                pred = np.any(np.array(image) != np.array(orig), axis=-1)
+                mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
+                mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
+                blur = ImageFilter.GaussianBlur(mask_blur)
+                image = Image.composite(image.filter(blur), orig, mask.filter(blur))
+
+            image = image.convert("RGB")
+        # Uploaded mask
         else:
             image = init_img_inpaint
             mask = init_mask_inpaint
+    # No mask
     else:
         image = init_img
         mask = None
 
+    # Use the EXIF orientation of photos taken by smartphones.
+    if image is not None:
+        image = ImageOps.exif_transpose(image)
+
     assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
 
     p = StableDiffusionProcessingImg2Img(
@@ -89,7 +112,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         seed_resize_from_h=seed_resize_from_h,
         seed_resize_from_w=seed_resize_from_w,
         seed_enable_extras=seed_enable_extras,
-        sampler_index=sampler_index,
+        sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
         batch_size=batch_size,
         n_iter=n_iter,
         steps=steps,
@@ -109,6 +132,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         inpainting_mask_invert=inpainting_mask_invert,
     )
 
+    p.scripts = modules.scripts.scripts_txt2img
+    p.script_args = args
+
     if shared.cmd_opts.enable_console_prompts:
         print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
 
@@ -125,6 +151,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         if processed is None:
             processed = process_images(p)
 
+    p.close()
+
     shared.total_tqdm.clear()
 
     generation_info_js = processed.js()
@@ -134,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
     if opts.do_not_show_images:
         processed.images = []
 
-    return processed.images, generation_info_js, plaintext_to_html(processed.info)
+    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/import_hook.py b/modules/import_hook.py
new file mode 100644
index 00000000..28c67dfa
--- /dev/null
+++ b/modules/import_hook.py
@@ -0,0 +1,5 @@
+import sys
+
+# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
+if "--xformers" not in "".join(sys.argv):
+    sys.modules["xformers"] = None
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 64b91eb4..738d8ff7 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,4 +1,3 @@
-import contextlib
 import os
 import sys
 import traceback
@@ -11,10 +10,9 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 import modules.shared as shared
-from modules import devices, paths, lowvram
+from modules import devices, paths, lowvram, modelloader
 
 blip_image_eval_size = 384
-blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
 clip_model_name = 'ViT-L/14'
 
 Category = namedtuple("Category", ["name", "topn", "items"])
@@ -28,9 +26,11 @@ class InterrogateModels:
     clip_preprocess = None
     categories = None
     dtype = None
+    running_on_cpu = None
 
     def __init__(self, content_dir):
         self.categories = []
+        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
 
         if os.path.exists(content_dir):
             for filename in os.listdir(content_dir):
@@ -45,7 +45,14 @@ class InterrogateModels:
     def load_blip_model(self):
         import models.blip
 
-        blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
+        files = modelloader.load_models(
+            model_path=os.path.join(paths.models_path, "BLIP"),
+            model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+            ext_filter=[".pth"],
+            download_name='model_base_caption_capfilt_large.pth',
+        )
+
+        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
         blip_model.eval()
 
         return blip_model
@@ -53,7 +60,11 @@ class InterrogateModels:
     def load_clip_model(self):
         import clip
 
-        model, preprocess = clip.load(clip_model_name)
+        if self.running_on_cpu:
+            model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
+        else:
+            model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
+
         model.eval()
         model = model.to(devices.device_interrogate)
 
@@ -62,14 +73,14 @@ class InterrogateModels:
     def load(self):
         if self.blip_model is None:
             self.blip_model = self.load_blip_model()
-            if not shared.cmd_opts.no_half:
+            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                 self.blip_model = self.blip_model.half()
 
         self.blip_model = self.blip_model.to(devices.device_interrogate)
 
         if self.clip_model is None:
             self.clip_model, self.clip_preprocess = self.load_clip_model()
-            if not shared.cmd_opts.no_half:
+            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                 self.clip_model = self.clip_model.half()
 
         self.clip_model = self.clip_model.to(devices.device_interrogate)
@@ -124,8 +135,9 @@ class InterrogateModels:
         return caption[0]
 
     def interrogate(self, pil_image):
-        res = None
-
+        res = ""
+        shared.state.begin()
+        shared.state.job = 'interrogate'
         try:
 
             if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -142,8 +154,7 @@ class InterrogateModels:
 
             clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
 
-            precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
-            with torch.no_grad(), precision_scope("cuda"):
+            with torch.no_grad(), devices.autocast():
                 image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
 
                 image_features /= image_features.norm(dim=-1, keepdim=True)
@@ -162,10 +173,11 @@ class InterrogateModels:
                             res += ", " + match
 
         except Exception:
-            print(f"Error interrogating", file=sys.stderr)
+            print("Error interrogating", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)
             res += "<error>"
 
         self.unload()
+        shared.state.end()
 
         return res
diff --git a/modules/localization.py b/modules/localization.py
index b1810cda..f6a6f2fb 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -3,6 +3,7 @@ import os
 import sys
 import traceback
 
+
 localizations = {}
 
 
@@ -16,6 +17,11 @@ def list_localizations(dirname):
 
         localizations[fn] = os.path.join(dirname, file)
 
+    from modules import scripts
+    for file in scripts.list_scripts("localizations", ".json"):
+        fn, ext = os.path.splitext(file.filename)
+        localizations[fn] = file.path
+
 
 def localization_js(current_localization_name):
     fn = localizations.get(current_localization_name, None)
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 7eba1349..042a0254 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,9 +1,8 @@
 import torch
-from modules.devices import get_optimal_device
+from modules import devices
 
 module_in_gpu = None
 cpu = torch.device("cpu")
-device = gpu = get_optimal_device()
 
 
 def send_everything_to_cpu():
@@ -33,34 +32,49 @@ def setup_for_low_vram(sd_model, use_medvram):
         if module_in_gpu is not None:
             module_in_gpu.to(cpu)
 
-        module.to(gpu)
+        module.to(devices.device)
         module_in_gpu = module
 
     # see below for register_forward_pre_hook;
     # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
     # useless here, and we just replace those methods
-    def first_stage_model_encode_wrap(self, encoder, x):
-        send_me_to_gpu(self, None)
-        return encoder(x)
 
-    def first_stage_model_decode_wrap(self, decoder, z):
-        send_me_to_gpu(self, None)
-        return decoder(z)
+    first_stage_model = sd_model.first_stage_model
+    first_stage_model_encode = sd_model.first_stage_model.encode
+    first_stage_model_decode = sd_model.first_stage_model.decode
 
-    # remove three big modules, cond, first_stage, and unet from the model and then
+    def first_stage_model_encode_wrap(x):
+        send_me_to_gpu(first_stage_model, None)
+        return first_stage_model_encode(x)
+
+    def first_stage_model_decode_wrap(z):
+        send_me_to_gpu(first_stage_model, None)
+        return first_stage_model_decode(z)
+
+    # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
+    if hasattr(sd_model.cond_stage_model, 'model'):
+        sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
+
+    # remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
     # send the model to GPU. Then put modules back. the modules will be in CPU.
-    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
-    sd_model.to(device)
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
+    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
+    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
+    sd_model.to(devices.device)
+    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
 
-    # register hooks for those the first two models
+    # register hooks for those the first three models
     sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
-    sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
-    sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
+    if sd_model.depth_model:
+        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
     parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
 
+    if hasattr(sd_model.cond_stage_model, 'model'):
+        sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
+        del sd_model.cond_stage_model.transformer
+
     if use_medvram:
         sd_model.model.register_forward_pre_hook(send_me_to_gpu)
     else:
@@ -70,7 +84,7 @@ def setup_for_low_vram(sd_model, use_medvram):
         # so that only one of them is in GPU at a time
         stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
         diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
-        sd_model.model.to(device)
+        sd_model.model.to(devices.device)
         diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
 
         # install hooks for bits of third model
diff --git a/modules/masking.py b/modules/masking.py
index fd8d9241..a5c4d2da 100644
--- a/modules/masking.py
+++ b/modules/masking.py
@@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w
     ratio_processing = processing_width / processing_height
 
     if ratio_crop_region > ratio_processing:
-        desired_height = (x2 - x1) * ratio_processing
+        desired_height = (x2 - x1) / ratio_processing
         desired_height_diff = int(desired_height - (y2-y1))
         y1 -= desired_height_diff//2
         y2 += desired_height_diff - desired_height_diff//2
diff --git a/modules/memmon.py b/modules/memmon.py
index 9fb9b687..a7060f58 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
     def read(self):
         if not self.disabled:
             free, total = torch.cuda.mem_get_info()
+            self.data["free"] = free
             self.data["total"] = total
 
             torch_stats = torch.cuda.memory_stats(self.device)
+            self.data["active"] = torch_stats["active.all.current"]
             self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
             self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
             self.data["system_peak"] = total - self.data["min_free"]
 
diff --git a/modules/modelloader.py b/modules/modelloader.py
index b0f2f33d..6a1a7ac8 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -82,9 +82,13 @@ def cleanup_models():
     src_path = models_path
     dest_path = os.path.join(models_path, "Stable-diffusion")
     move_files(src_path, dest_path, ".ckpt")
+    move_files(src_path, dest_path, ".safetensors")
     src_path = os.path.join(root_path, "ESRGAN")
     dest_path = os.path.join(models_path, "ESRGAN")
     move_files(src_path, dest_path)
+    src_path = os.path.join(models_path, "BSRGAN")
+    dest_path = os.path.join(models_path, "ESRGAN")
+    move_files(src_path, dest_path, ".pth")
     src_path = os.path.join(root_path, "gfpgan")
     dest_path = os.path.join(models_path, "GFPGAN")
     move_files(src_path, dest_path)
@@ -119,11 +123,27 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
         pass
 
 
+builtin_upscaler_classes = []
+forbidden_upscaler_classes = set()
+
+
+def list_builtin_upscalers():
+    load_upscalers()
+
+    builtin_upscaler_classes.clear()
+    builtin_upscaler_classes.extend(Upscaler.__subclasses__())
+
+
+def forbid_loaded_nonbuiltin_upscalers():
+    for cls in Upscaler.__subclasses__():
+        if cls not in builtin_upscaler_classes:
+            forbidden_upscaler_classes.add(cls)
+
+
 def load_upscalers():
-    sd = shared.script_path
     # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
     # so we'll try to import any _model.py files before looking in __subclasses__
-    modules_dir = os.path.join(sd, "modules")
+    modules_dir = os.path.join(shared.script_path, "modules")
     for file in os.listdir(modules_dir):
         if "_model.py" in file:
             model_name = file.replace("_model.py", "")
@@ -132,22 +152,16 @@ def load_upscalers():
                 importlib.import_module(full_model)
             except:
                 pass
+
     datas = []
-    c_o = vars(shared.cmd_opts)
+    commandline_options = vars(shared.cmd_opts)
     for cls in Upscaler.__subclasses__():
+        if cls in forbidden_upscaler_classes:
+            continue
+
         name = cls.__name__
-        module_name = cls.__module__
-        module = importlib.import_module(module_name)
-        class_ = getattr(module, name)
         cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
-        opt_string = None
-        try:
-            if cmd_name in c_o:
-                opt_string = c_o[cmd_name]
-        except:
-            pass
-        scaler = class_(opt_string)
-        for child in scaler.scalers:
-            datas.append(child)
+        scaler = cls(commandline_options.get(cmd_name, None))
+        datas += scaler.scalers
 
     shared.sd_upscalers = datas
diff --git a/modules/ngrok.py b/modules/ngrok.py
index 5c5f349a..3df2c06b 100644
--- a/modules/ngrok.py
+++ b/modules/ngrok.py
@@ -1,14 +1,23 @@
 from pyngrok import ngrok, conf, exception
 
-
 def connect(token, port, region):
-    if token == None:
+    account = None
+    if token is None:
         token = 'None'
+    else:
+        if ':' in token:
+            # token = authtoken:username:password
+            account = token.split(':')[1] + ':' + token.split(':')[-1]
+            token = token.split(':')[0]
+
     config = conf.PyngrokConfig(
         auth_token=token, region=region
     )
     try:
-        public_url = ngrok.connect(port, pyngrok_config=config).public_url
+        if account is None:
+            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
+        else:
+            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
     except exception.PyngrokNgrokError:
         print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
               f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
diff --git a/modules/paths.py b/modules/paths.py
index 1e7a2fbc..4dd03a35 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -9,7 +9,7 @@ sys.path.insert(0, script_path)
 
 # search for directory of stable diffusion in following places
 sd_path = None
-possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
+possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
 for possible_sd_path in possible_sd_paths:
     if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
         sd_path = os.path.abspath(possible_sd_path)
diff --git a/modules/processing.py b/modules/processing.py
index bcb0c32c..c7264aff 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -2,6 +2,7 @@ import json
 import math
 import os
 import sys
+import warnings
 
 import torch
 import numpy as np
@@ -12,15 +13,21 @@ from skimage import exposure
 from typing import Any, Dict, List, Optional
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
 import modules.face_restoration
 import modules.images as images
 import modules.styles
+import modules.sd_models as sd_models
+import modules.sd_vae as sd_vae
 import logging
+from ldm.data.util import AddMiDaS
+from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
 
+from einops import repeat, rearrange
+from blendmodes.blend import blendLayers, BlendType
 
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 opt_C = 4
@@ -33,34 +40,68 @@ def setup_color_correction(image):
     return correction_target
 
 
-def apply_color_correction(correction, image):
+def apply_color_correction(correction, original_image):
     logging.info("Applying color correction.")
     image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
         cv2.cvtColor(
-            np.asarray(image),
+            np.asarray(original_image),
             cv2.COLOR_RGB2LAB
         ),
         correction,
         channel_axis=2
     ), cv2.COLOR_LAB2RGB).astype("uint8"))
+    
+    image = blendLayers(image, original_image, BlendType.LUMINOSITY)
+    
+    return image
+
+
+def apply_overlay(image, paste_loc, index, overlays):
+    if overlays is None or index >= len(overlays):
+        return image
+
+    overlay = overlays[index]
+
+    if paste_loc is not None:
+        x, y, w, h = paste_loc
+        base_image = Image.new('RGBA', (overlay.width, overlay.height))
+        image = images.resize_image(1, image, w, h)
+        base_image.paste(image, (x, y))
+        image = base_image
+
+    image = image.convert('RGBA')
+    image.alpha_composite(overlay)
+    image = image.convert('RGB')
 
     return image
 
 
-def get_correct_sampler(p):
-    if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
-        return sd_samplers.samplers
-    elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
-        return sd_samplers.samplers_for_img2img
-    elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
-        return sd_samplers.samplers
+def txt2img_image_conditioning(sd_model, x, width, height):
+    if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
+        # Dummy zero conditioning if we're not using inpainting model.
+        # Still takes up a bit of memory, but no encoder call.
+        # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+    # The "masked-image" in this case will just be all zeros since the entire image is masked.
+    image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+    image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+
+    # Add the fake full 1s mask to the first dimension.
+    image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+    image_conditioning = image_conditioning.to(x.dtype)
+
+    return image_conditioning
+
 
 class StableDiffusionProcessing():
     """
     The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
-    
     """
-    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
+    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
+        if sampler_index is not None:
+            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
+
         self.sd_model = sd_model
         self.outpath_samples: str = outpath_samples
         self.outpath_grids: str = outpath_grids
@@ -73,7 +114,7 @@ class StableDiffusionProcessing():
         self.subseed_strength: float = subseed_strength
         self.seed_resize_from_h: int = seed_resize_from_h
         self.seed_resize_from_w: int = seed_resize_from_w
-        self.sampler_index: int = sampler_index
+        self.sampler_name: str = sampler_name
         self.batch_size: int = batch_size
         self.n_iter: int = n_iter
         self.steps: int = steps
@@ -90,13 +131,16 @@ class StableDiffusionProcessing():
         self.do_not_reload_embeddings = do_not_reload_embeddings
         self.paste_to = None
         self.color_corrections = None
-        self.denoising_strength: float = 0
+        self.denoising_strength: float = denoising_strength
         self.sampler_noise_scheduler_override = None
-        self.ddim_discretize = opts.ddim_discretize
+        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
         self.s_churn = s_churn or opts.s_churn
         self.s_tmin = s_tmin or opts.s_tmin
         self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
         self.s_noise = s_noise or opts.s_noise
+        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
+        self.override_settings_restore_afterwards = override_settings_restore_afterwards
+        self.is_using_inpainting_conditioning = False
 
         if not seed_enable_extras:
             self.subseed = -1
@@ -104,16 +148,100 @@ class StableDiffusionProcessing():
             self.seed_resize_from_h = 0
             self.seed_resize_from_w = 0
 
+        self.scripts = None
+        self.script_args = None
+        self.all_prompts = None
+        self.all_negative_prompts = None
+        self.all_seeds = None
+        self.all_subseeds = None
+        self.iteration = 0
+
+    def txt2img_image_conditioning(self, x, width=None, height=None):
+        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+
+        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
+
+    def depth2img_image_conditioning(self, source_image):
+        # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
+        transformer = AddMiDaS(model_type="dpt_hybrid")
+        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
+        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
+        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
+
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+        conditioning = torch.nn.functional.interpolate(
+            self.sd_model.depth_model(midas_in),
+            size=conditioning_image.shape[2:],
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        (depth_min, depth_max) = torch.aminmax(conditioning)
+        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
+        return conditioning
+
+    def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
+        self.is_using_inpainting_conditioning = True
+
+        # Handle the different mask inputs
+        if image_mask is not None:
+            if torch.is_tensor(image_mask):
+                conditioning_mask = image_mask
+            else:
+                conditioning_mask = np.array(image_mask.convert("L"))
+                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+                conditioning_mask = torch.round(conditioning_mask)
+        else:
+            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
+
+        # Create another latent image, this time with a masked version of the original input.
+        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
+        conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
+        conditioning_image = torch.lerp(
+            source_image,
+            source_image * (1.0 - conditioning_mask),
+            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+        )
+
+        # Encode the new masked image using first stage of network.
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+        # Create the concatenated conditioning tensor to be fed to `c_concat`
+        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
+        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
+        return image_conditioning
+
+    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
+        # identify itself with a field common to all models. The conditioning_key is also hybrid.
+        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
+            return self.depth2img_image_conditioning(source_image)
+
+        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
+        # Dummy zero conditioning if we're not using inpainting or depth model.
+        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
 
     def init(self, all_prompts, all_seeds, all_subseeds):
         pass
 
-    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         raise NotImplementedError()
 
+    def close(self):
+        self.sd_model = None
+        self.sampler = None
+
 
 class Processed:
-    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
+    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
         self.images = images_list
         self.prompt = p.prompt
         self.negative_prompt = p.negative_prompt
@@ -121,10 +249,10 @@ class Processed:
         self.subseed = subseed
         self.subseed_strength = p.subseed_strength
         self.info = info
+        self.comments = comments
         self.width = p.width
         self.height = p.height
-        self.sampler_index = p.sampler_index
-        self.sampler = sd_samplers.samplers[p.sampler_index].name
+        self.sampler_name = p.sampler_name
         self.cfg_scale = p.cfg_scale
         self.steps = p.steps
         self.batch_size = p.batch_size
@@ -151,17 +279,20 @@ class Processed:
         self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
         self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
         self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
+        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
 
-        self.all_prompts = all_prompts or [self.prompt]
-        self.all_seeds = all_seeds or [self.seed]
-        self.all_subseeds = all_subseeds or [self.subseed]
+        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
+        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
+        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
+        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
         self.infotexts = infotexts or [info]
 
     def js(self):
         obj = {
-            "prompt": self.prompt,
+            "prompt": self.all_prompts[0],
             "all_prompts": self.all_prompts,
-            "negative_prompt": self.negative_prompt,
+            "negative_prompt": self.all_negative_prompts[0],
+            "all_negative_prompts": self.all_negative_prompts,
             "seed": self.seed,
             "all_seeds": self.all_seeds,
             "subseed": self.subseed,
@@ -169,8 +300,7 @@ class Processed:
             "subseed_strength": self.subseed_strength,
             "width": self.width,
             "height": self.height,
-            "sampler_index": self.sampler_index,
-            "sampler": self.sampler,
+            "sampler_name": self.sampler_name,
             "cfg_scale": self.cfg_scale,
             "steps": self.steps,
             "batch_size": self.batch_size,
@@ -186,11 +316,12 @@ class Processed:
             "styles": self.styles,
             "job_timestamp": self.job_timestamp,
             "clip_skip": self.clip_skip,
+            "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
         }
 
         return json.dumps(obj)
 
-    def infotext(self,  p: StableDiffusionProcessing, index):
+    def infotext(self, p: StableDiffusionProcessing, index):
         return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
 
 
@@ -210,13 +341,14 @@ def slerp(val, low, high):
 
 
 def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+    eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
     xs = []
 
     # if we have multiple seeds, this means we are working with batch size>1; this then
     # enables the generation of additional tensors with noise that the sampler will use during its processing.
     # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
     # produce the same images as with two batches [100], [101].
-    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
+    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
         sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
     else:
         sampler_noises = None
@@ -256,8 +388,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
         if sampler_noises is not None:
             cnt = p.sampler.number_of_needed_noises(p)
 
-            if opts.eta_noise_seed_delta > 0:
-                torch.manual_seed(seed + opts.eta_noise_seed_delta)
+            if eta_noise_seed_delta > 0:
+                torch.manual_seed(seed + eta_noise_seed_delta)
 
             for j in range(cnt):
                 sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -297,20 +429,23 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
 
     generation_params = {
         "Steps": p.steps,
-        "Sampler": get_correct_sampler(p)[p.sampler_index].name,
+        "Sampler": p.sampler_name,
         "CFG scale": p.cfg_scale,
         "Seed": all_seeds[index],
         "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
-        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
+        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+        "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
+        "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
         "Batch size": (None if p.batch_size < 2 else p.batch_size),
         "Batch pos": (None if p.batch_size < 2 else position_in_batch),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
         "Denoising strength": getattr(p, 'denoising_strength', None),
+        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
         "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
         "Clip skip": None if clip_skip <= 1 else clip_skip,
         "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
@@ -318,14 +453,38 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
 
     generation_params.update(p.extra_generation_params)
 
-    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
 
-    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
+    negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
 
     return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
 
 
 def process_images(p: StableDiffusionProcessing) -> Processed:
+    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
+
+    try:
+        for k, v in p.override_settings.items():
+            setattr(opts, k, v)
+            if k == 'sd_hypernetwork': shared.reload_hypernetworks()  # make onchange call for changing hypernet
+            if k == 'sd_model_checkpoint': sd_models.reload_model_weights()  # make onchange call for changing SD model
+            if k == 'sd_vae': sd_vae.reload_vae_weights()  # make onchange call for changing VAE
+
+        res = process_images_inner(p)
+
+    finally:
+        # restore opts to original state
+        if p.override_settings_restore_afterwards:
+            for k, v in stored_opts.items():
+                setattr(opts, k, v)
+                if k == 'sd_hypernetwork': shared.reload_hypernetworks()
+                if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
+                if k == 'sd_vae': sd_vae.reload_vae_weights()
+
+    return res
+
+
+def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
 
     if type(p.prompt) == list:
@@ -333,10 +492,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     else:
         assert p.prompt is not None
 
-    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
-        processed = Processed(p, [], p.seed, "")
-        file.write(processed.infotext(p, 0))
-
     devices.torch_gc()
 
     seed = get_fixed_seed(p.seed)
@@ -347,57 +502,71 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
     comments = {}
 
-    shared.prompt_styles.apply_styles(p)
-
     if type(p.prompt) == list:
-        all_prompts = p.prompt
+        p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
     else:
-        all_prompts = p.batch_size * p.n_iter * [p.prompt]
+        p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
+
+    if type(p.negative_prompt) == list:
+        p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
+    else:
+        p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
 
     if type(seed) == list:
-        all_seeds = seed
+        p.all_seeds = seed
     else:
-        all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
+        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
 
     if type(subseed) == list:
-        all_subseeds = subseed
+        p.all_subseeds = subseed
     else:
-        all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
+        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
 
     def infotext(iteration=0, position_in_batch=0):
-        return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
+        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+
+    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+        processed = Processed(p, [], p.seed, "")
+        file.write(processed.infotext(p, 0))
 
     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
         model_hijack.embedding_db.load_textual_inversion_embeddings()
 
+    if p.scripts is not None:
+        p.scripts.process(p)
+
     infotexts = []
     output_images = []
 
     with torch.no_grad(), p.sd_model.ema_scope():
         with devices.autocast():
-            p.init(all_prompts, all_seeds, all_subseeds)
+            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
 
         if state.job_count == -1:
             state.job_count = p.n_iter
 
         for n in range(p.n_iter):
+            p.iteration = n
+
             if state.skipped:
                 state.skipped = False
-            
+
             if state.interrupted:
                 break
 
-            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
-            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
-            subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+            prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+            negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
 
-            if (len(prompts) == 0):
+            if len(prompts) == 0:
                 break
 
-            #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
-            #c = p.sd_model.get_learned_conditioning(prompts)
+            if p.scripts is not None:
+                p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+
             with devices.autocast():
-                uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
+                uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
                 c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
 
             if len(model_hijack.comments) > 0:
@@ -408,10 +577,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
             with devices.autocast():
-                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
 
-            samples_ddim = samples_ddim.to(devices.dtype_vae)
-            x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
+            x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
+            x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
             del samples_ddim
@@ -421,9 +590,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
             devices.torch_gc()
 
-            if opts.filter_nsfw:
-                import modules.safety as safety
-                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
+            if p.scripts is not None:
+                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
 
             for i, x_sample in enumerate(x_samples_ddim):
                 x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -442,22 +610,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
                 if p.color_corrections is not None and i < len(p.color_corrections):
                     if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
-                        images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+                        images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                     image = apply_color_correction(p.color_corrections[i], image)
 
-                if p.overlay_images is not None and i < len(p.overlay_images):
-                    overlay = p.overlay_images[i]
-
-                    if p.paste_to is not None:
-                        x, y, w, h = p.paste_to
-                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
-                        image = images.resize_image(1, image, w, h)
-                        base_image.paste(image, (x, y))
-                        image = base_image
-
-                    image = image.convert('RGBA')
-                    image.alpha_composite(overlay)
-                    image = image.convert('RGB')
+                image = apply_overlay(image, p.paste_to, i, p.overlay_images)
 
                 if opts.samples_save and not p.do_not_save_samples:
                     images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
@@ -468,7 +625,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                     image.info["parameters"] = text
                 output_images.append(image)
 
-            del x_samples_ddim 
+            del x_samples_ddim
 
             devices.torch_gc()
 
@@ -490,23 +647,33 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                 index_of_first_image = 1
 
             if opts.grid_save:
-                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
 
     devices.torch_gc()
-    return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+    res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+    if p.scripts is not None:
+        p.scripts.postprocess(p, res)
+
+    return res
 
 
 class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     sampler = None
 
-    def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
+    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
         super().__init__(**kwargs)
         self.enable_hr = enable_hr
         self.denoising_strength = denoising_strength
-        self.firstphase_width = firstphase_width
-        self.firstphase_height = firstphase_height
-        self.truncate_x = 0
-        self.truncate_y = 0
+        self.hr_scale = hr_scale
+        self.hr_upscaler = hr_upscaler
+
+        if firstphase_width != 0 or firstphase_height != 0:
+            print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
+            self.hr_scale = self.width / firstphase_width
+            self.width = firstphase_width
+            self.height = firstphase_height
 
     def init(self, all_prompts, all_seeds, all_subseeds):
         if self.enable_hr:
@@ -515,48 +682,50 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             else:
                 state.job_count = state.job_count * 2
 
-            if self.firstphase_width == 0 or self.firstphase_height == 0:
-                desired_pixel_count = 512 * 512
-                actual_pixel_count = self.width * self.height
-                scale = math.sqrt(desired_pixel_count / actual_pixel_count)
-                self.firstphase_width = math.ceil(scale * self.width / 64) * 64
-                self.firstphase_height = math.ceil(scale * self.height / 64) * 64
-                firstphase_width_truncated = int(scale * self.width)
-                firstphase_height_truncated = int(scale * self.height)
+            self.extra_generation_params["Hires upscale"] = self.hr_scale
+            if self.hr_upscaler is not None:
+                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
 
-            else:
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
-                width_ratio = self.width / self.firstphase_width
-                height_ratio = self.height / self.firstphase_height
+        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+        if self.enable_hr and latent_scale_mode is None:
+            assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
 
-                if width_ratio > height_ratio:
-                    firstphase_width_truncated = self.firstphase_width
-                    firstphase_height_truncated = self.firstphase_width * self.height / self.width
-                else:
-                    firstphase_width_truncated = self.firstphase_height * self.width / self.height
-                    firstphase_height_truncated = self.firstphase_height
-
-            self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
-            self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
-            self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
-
-
-    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
-        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
 
         if not self.enable_hr:
-            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
             return samples
 
-        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+        target_width = int(self.width * self.hr_scale)
+        target_height = int(self.height * self.hr_scale)
 
-        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+        def save_intermediate(image, index):
+            """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
 
-        if opts.use_scale_latent_for_hires_fix:
-            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+            if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+                return
 
+            if not isinstance(image, Image.Image):
+                image = sd_samplers.sample_to_image(image, index, approximation=0)
+
+            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
+            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
+
+        if latent_scale_mode is not None:
+            for i in range(samples.shape[0]):
+                save_intermediate(samples, i)
+
+            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+
+            # Avoid making the inpainting conditioning unless necessary as
+            # this does need some extra compute to decode / encode the image again.
+            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+            else:
+                image_conditioning = self.txt2img_image_conditioning(samples)
         else:
             decoded_samples = decode_first_stage(self.sd_model, samples)
             lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -566,7 +735,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                 x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                 x_sample = x_sample.astype(np.uint8)
                 image = Image.fromarray(x_sample)
-                image = images.resize_image(0, image, self.width, self.height)
+
+                save_intermediate(image, i)
+
+                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                 image = np.array(image).astype(np.float32) / 255.0
                 image = np.moveaxis(image, 2, 0)
                 batch_images.append(image)
@@ -577,17 +749,19 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
             samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
 
+            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
+
         shared.state.nextjob()
 
-        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
-        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
 
         # GC now before running the next img2img to prevent running out of memory
         x = None
         devices.torch_gc()
 
-        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
 
         return samples
 
@@ -595,7 +769,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     sampler = None
 
-    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
+    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
         super().__init__(**kwargs)
 
         self.init_images = init_images
@@ -603,7 +777,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.denoising_strength: float = denoising_strength
         self.init_latent = None
         self.image_mask = mask
-        #self.image_unblurred_mask = None
         self.latent_mask = None
         self.mask_for_overlay = None
         self.mask_blur = mask_blur
@@ -611,65 +784,68 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.inpaint_full_res = inpaint_full_res
         self.inpaint_full_res_padding = inpaint_full_res_padding
         self.inpainting_mask_invert = inpainting_mask_invert
+        self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
         self.mask = None
         self.nmask = None
+        self.image_conditioning = None
 
     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
+        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
         crop_region = None
 
-        if self.image_mask is not None:
-            self.image_mask = self.image_mask.convert('L')
+        image_mask = self.image_mask
+
+        if image_mask is not None:
+            image_mask = image_mask.convert('L')
 
             if self.inpainting_mask_invert:
-                self.image_mask = ImageOps.invert(self.image_mask)
-
-            #self.image_unblurred_mask = self.image_mask
+                image_mask = ImageOps.invert(image_mask)
 
             if self.mask_blur > 0:
-                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
+                image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
 
             if self.inpaint_full_res:
-                self.mask_for_overlay = self.image_mask
-                mask = self.image_mask.convert('L')
+                self.mask_for_overlay = image_mask
+                mask = image_mask.convert('L')
                 crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                 crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                 x1, y1, x2, y2 = crop_region
 
                 mask = mask.crop(crop_region)
-                self.image_mask = images.resize_image(2, mask, self.width, self.height)
+                image_mask = images.resize_image(2, mask, self.width, self.height)
                 self.paste_to = (x1, y1, x2-x1, y2-y1)
             else:
-                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
-                np_mask = np.array(self.image_mask)
+                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
+                np_mask = np.array(image_mask)
                 np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                 self.mask_for_overlay = Image.fromarray(np_mask)
 
             self.overlay_images = []
 
-        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
+        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
 
         add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
         if add_color_corrections:
             self.color_corrections = []
         imgs = []
         for img in self.init_images:
-            image = img.convert("RGB")
+            image = images.flatten(img, opts.img2img_background_color)
 
-            if crop_region is None:
+            if crop_region is None and self.resize_mode != 3:
                 image = images.resize_image(self.resize_mode, image, self.width, self.height)
 
-            if self.image_mask is not None:
+            if image_mask is not None:
                 image_masked = Image.new('RGBa', (image.width, image.height))
                 image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
 
                 self.overlay_images.append(image_masked.convert('RGBA'))
 
+            # crop_region is not None if we are doing inpaint full res
             if crop_region is not None:
                 image = image.crop(crop_region)
                 image = images.resize_image(2, image, self.width, self.height)
 
-            if self.image_mask is not None:
+            if image_mask is not None:
                 if self.inpainting_fill != 1:
                     image = masking.fill(image, latent_mask)
 
@@ -685,6 +861,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
             if self.overlay_images is not None:
                 self.overlay_images = self.overlay_images * self.batch_size
+
+            if self.color_corrections is not None and len(self.color_corrections) == 1:
+                self.color_corrections = self.color_corrections * self.batch_size
+
         elif len(imgs) <= self.batch_size:
             self.batch_size = len(imgs)
             batch_images = np.array(imgs)
@@ -697,7 +877,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
 
         self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
 
-        if self.image_mask is not None:
+        if self.resize_mode == 3:
+            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
+        if image_mask is not None:
             init_mask = latent_mask
             latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
             latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
@@ -714,10 +897,16 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             elif self.inpainting_fill == 3:
                 self.init_latent = self.init_latent * self.mask
 
-    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
+
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
 
-        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
+        if self.initial_noise_multiplier != 1.0:
+            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
+            x *= self.initial_noise_multiplier
+
+        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
 
         if self.mask is not None:
             samples = samples * self.nmask + self.init_latent * self.mask
@@ -725,4 +914,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         del x
         devices.torch_gc()
 
-        return samples
\ No newline at end of file
+        return samples
diff --git a/modules/safe.py b/modules/safe.py
index 399165a1..82d44be3 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -23,23 +23,30 @@ def encode(*args):
 
 
 class RestrictedUnpickler(pickle.Unpickler):
+    extra_handler = None
+
     def persistent_load(self, saved_id):
         assert saved_id[0] == 'storage'
         return TypedStorage()
 
     def find_class(self, module, name):
+        if self.extra_handler is not None:
+            res = self.extra_handler(module, name)
+            if res is not None:
+                return res
+
         if module == 'collections' and name == 'OrderedDict':
             return getattr(collections, name)
-        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
+        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
             return getattr(torch._utils, name)
-        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
+        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
             return getattr(torch, name)
         if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
             return getattr(torch.nn.modules.container, name)
-        if module == 'numpy.core.multiarray' and name == 'scalar':
-            return numpy.core.multiarray.scalar
-        if module == 'numpy' and name == 'dtype':
-            return numpy.dtype
+        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
+            return getattr(numpy.core.multiarray, name)
+        if module == 'numpy' and name in ['dtype', 'ndarray']:
+            return getattr(numpy, name)
         if module == '_codecs' and name == 'encode':
             return encode
         if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
@@ -52,32 +59,37 @@ class RestrictedUnpickler(pickle.Unpickler):
             return set
 
         # Forbid everything else.
-        raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
+        raise Exception(f"global '{module}/{name}' is forbidden")
 
 
-allowed_zip_names = ["archive/data.pkl", "archive/version"]
-allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
-
+# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
+allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
+data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
 
 def check_zip_filenames(filename, names):
     for name in names:
-        if name in allowed_zip_names:
-            continue
         if allowed_zip_names_re.match(name):
             continue
 
         raise Exception(f"bad file inside {filename}: {name}")
 
 
-def check_pt(filename):
+def check_pt(filename, extra_handler):
     try:
 
         # new pytorch format is a zip file
         with zipfile.ZipFile(filename) as z:
             check_zip_filenames(filename, z.namelist())
 
-            with z.open('archive/data.pkl') as file:
+            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
+            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
+            if len(data_pkl_filenames) == 0:
+                raise Exception(f"data.pkl not found in {filename}")
+            if len(data_pkl_filenames) > 1:
+                raise Exception(f"Multiple data.pkl found in {filename}")
+            with z.open(data_pkl_filenames[0]) as file:
                 unpickler = RestrictedUnpickler(file)
+                unpickler.extra_handler = extra_handler
                 unpickler.load()
 
     except zipfile.BadZipfile:
@@ -85,33 +97,96 @@ def check_pt(filename):
         # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
         with open(filename, "rb") as file:
             unpickler = RestrictedUnpickler(file)
+            unpickler.extra_handler = extra_handler
             for i in range(5):
                 unpickler.load()
 
 
 def load(filename, *args, **kwargs):
+    return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
+
+
+def load_with_extra(filename, extra_handler=None, *args, **kwargs):
+    """
+    this function is intended to be used by extensions that want to load models with
+    some extra classes in them that the usual unpickler would find suspicious.
+
+    Use the extra_handler argument to specify a function that takes module and field name as text,
+    and returns that field's value:
+
+    ```python
+    def extra(module, name):
+        if module == 'collections' and name == 'OrderedDict':
+            return collections.OrderedDict
+
+        return None
+
+    safe.load_with_extra('model.pt', extra_handler=extra)
+    ```
+
+    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
+    definitely unsafe.
+    """
+
     from modules import shared
 
     try:
         if not shared.cmd_opts.disable_safe_unpickle:
-            check_pt(filename)
+            check_pt(filename, extra_handler)
 
     except pickle.UnpicklingError:
         print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
-        print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
-        print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+        print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+        print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
         return None
 
     except Exception:
         print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
-        print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
-        print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+        print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
+        print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
         return None
 
     return unsafe_torch_load(filename, *args, **kwargs)
 
 
+class Extra:
+    """
+    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
+    (because it's not your code making the torch.load call). The intended use is like this:
+
+```
+import torch
+from modules import safe
+
+def handler(module, name):
+    if module == 'torch' and name in ['float64', 'float16']:
+        return getattr(torch, name)
+
+    return None
+
+with safe.Extra(handler):
+    x = torch.load('model.pt')
+```
+    """
+
+    def __init__(self, handler):
+        self.handler = handler
+
+    def __enter__(self):
+        global global_extra_handler
+
+        assert global_extra_handler is None, 'already inside an Extra() block'
+        global_extra_handler = self.handler
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global global_extra_handler
+
+        global_extra_handler = None
+
+
 unsafe_torch_load = torch.load
 torch.load = load
+global_extra_handler = None
+
diff --git a/modules/safety.py b/modules/safety.py
deleted file mode 100644
index cff4b278..00000000
--- a/modules/safety.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-from PIL import Image
-
-import modules.shared as shared
-
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = None
-safety_checker = None
-
-def numpy_to_pil(images):
-    """
-    Convert a numpy image or a batch of images to a PIL image.
-    """
-    if images.ndim == 3:
-        images = images[None, ...]
-    images = (images * 255).round().astype("uint8")
-    pil_images = [Image.fromarray(image) for image in images]
-
-    return pil_images
-
-# check and replace nsfw content
-def check_safety(x_image):
-    global safety_feature_extractor, safety_checker
-
-    if safety_feature_extractor is None:
-        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-
-    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
-    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-
-    return x_checked_image, has_nsfw_concept
-
-
-def censor_batch(x):
-    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
-    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
-    x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
-    return x
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
new file mode 100644
index 00000000..de69fd9f
--- /dev/null
+++ b/modules/script_callbacks.py
@@ -0,0 +1,281 @@
+import sys
+import traceback
+from collections import namedtuple
+import inspect
+from typing import Optional
+
+from fastapi import FastAPI
+from gradio import Blocks
+
+
+def report_exception(c, job):
+    print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
+    print(traceback.format_exc(), file=sys.stderr)
+
+
+class ImageSaveParams:
+    def __init__(self, image, p, filename, pnginfo):
+        self.image = image
+        """the PIL image itself"""
+
+        self.p = p
+        """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
+
+        self.filename = filename
+        """name of file that the image would be saved to"""
+
+        self.pnginfo = pnginfo
+        """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
+
+
+class CFGDenoiserParams:
+    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
+        self.x = x
+        """Latent image representation in the process of being denoised"""
+        
+        self.image_cond = image_cond
+        """Conditioning image"""
+        
+        self.sigma = sigma
+        """Current sigma noise step value"""
+        
+        self.sampling_step = sampling_step
+        """Current Sampling step number"""
+        
+        self.total_sampling_steps = total_sampling_steps
+        """Total number of sampling steps planned"""
+
+
+class UiTrainTabParams:
+    def __init__(self, txt2img_preview_params):
+        self.txt2img_preview_params = txt2img_preview_params
+
+
+class ImageGridLoopParams:
+    def __init__(self, imgs, cols, rows):
+        self.imgs = imgs
+        self.cols = cols
+        self.rows = rows
+
+
+ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+callback_map = dict(
+    callbacks_app_started=[],
+    callbacks_model_loaded=[],
+    callbacks_ui_tabs=[],
+    callbacks_ui_train_tabs=[],
+    callbacks_ui_settings=[],
+    callbacks_before_image_saved=[],
+    callbacks_image_saved=[],
+    callbacks_cfg_denoiser=[],
+    callbacks_before_component=[],
+    callbacks_after_component=[],
+    callbacks_image_grid=[],
+)
+
+
+def clear_callbacks():
+    for callback_list in callback_map.values():
+        callback_list.clear()
+
+
+def app_started_callback(demo: Optional[Blocks], app: FastAPI):
+    for c in callback_map['callbacks_app_started']:
+        try:
+            c.callback(demo, app)
+        except Exception:
+            report_exception(c, 'app_started_callback')
+
+
+def model_loaded_callback(sd_model):
+    for c in callback_map['callbacks_model_loaded']:
+        try:
+            c.callback(sd_model)
+        except Exception:
+            report_exception(c, 'model_loaded_callback')
+
+
+def ui_tabs_callback():
+    res = []
+
+    for c in callback_map['callbacks_ui_tabs']:
+        try:
+            res += c.callback() or []
+        except Exception:
+            report_exception(c, 'ui_tabs_callback')
+
+    return res
+
+
+def ui_train_tabs_callback(params: UiTrainTabParams):
+    for c in callback_map['callbacks_ui_train_tabs']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'callbacks_ui_train_tabs')
+
+
+def ui_settings_callback():
+    for c in callback_map['callbacks_ui_settings']:
+        try:
+            c.callback()
+        except Exception:
+            report_exception(c, 'ui_settings_callback')
+
+
+def before_image_saved_callback(params: ImageSaveParams):
+    for c in callback_map['callbacks_before_image_saved']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'before_image_saved_callback')
+
+
+def image_saved_callback(params: ImageSaveParams):
+    for c in callback_map['callbacks_image_saved']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'image_saved_callback')
+
+
+def cfg_denoiser_callback(params: CFGDenoiserParams):
+    for c in callback_map['callbacks_cfg_denoiser']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'cfg_denoiser_callback')
+
+
+def before_component_callback(component, **kwargs):
+    for c in callback_map['callbacks_before_component']:
+        try:
+            c.callback(component, **kwargs)
+        except Exception:
+            report_exception(c, 'before_component_callback')
+
+
+def after_component_callback(component, **kwargs):
+    for c in callback_map['callbacks_after_component']:
+        try:
+            c.callback(component, **kwargs)
+        except Exception:
+            report_exception(c, 'after_component_callback')
+
+
+def image_grid_callback(params: ImageGridLoopParams):
+    for c in callback_map['callbacks_image_grid']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'image_grid')
+
+
+def add_callback(callbacks, fun):
+    stack = [x for x in inspect.stack() if x.filename != __file__]
+    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+
+    callbacks.append(ScriptCallback(filename, fun))
+
+    
+def remove_current_script_callbacks():
+    stack = [x for x in inspect.stack() if x.filename != __file__]
+    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+    if filename == 'unknown file':
+        return
+    for callback_list in callback_map.values():
+        for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
+            callback_list.remove(callback_to_remove)
+
+
+def remove_callbacks_for_function(callback_func):
+    for callback_list in callback_map.values():
+        for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
+            callback_list.remove(callback_to_remove)
+
+
+def on_app_started(callback):
+    """register a function to be called when the webui started, the gradio `Block` component and
+    fastapi `FastAPI` object are passed as the arguments"""
+    add_callback(callback_map['callbacks_app_started'], callback)
+
+
+def on_model_loaded(callback):
+    """register a function to be called when the stable diffusion model is created; the model is
+    passed as an argument"""
+    add_callback(callback_map['callbacks_model_loaded'], callback)
+
+
+def on_ui_tabs(callback):
+    """register a function to be called when the UI is creating new tabs.
+    The function must either return a None, which means no new tabs to be added, or a list, where
+    each element is a tuple:
+        (gradio_component, title, elem_id)
+
+    gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
+    title is tab text displayed to user in the UI
+    elem_id is HTML id for the tab
+    """
+    add_callback(callback_map['callbacks_ui_tabs'], callback)
+
+
+def on_ui_train_tabs(callback):
+    """register a function to be called when the UI is creating new tabs for the train tab.
+    Create your new tabs with gr.Tab.
+    """
+    add_callback(callback_map['callbacks_ui_train_tabs'], callback)
+
+
+def on_ui_settings(callback):
+    """register a function to be called before UI settings are populated; add your settings
+    by using shared.opts.add_option(shared.OptionInfo(...)) """
+    add_callback(callback_map['callbacks_ui_settings'], callback)
+
+
+def on_before_image_saved(callback):
+    """register a function to be called before an image is saved to a file.
+    The callback is called with one argument:
+        - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
+    """
+    add_callback(callback_map['callbacks_before_image_saved'], callback)
+
+
+def on_image_saved(callback):
+    """register a function to be called after an image is saved to a file.
+    The callback is called with one argument:
+        - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
+    """
+    add_callback(callback_map['callbacks_image_saved'], callback)
+
+
+def on_cfg_denoiser(callback):
+    """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
+    The callback is called with one argument:
+        - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
+    """
+    add_callback(callback_map['callbacks_cfg_denoiser'], callback)
+
+
+def on_before_component(callback):
+    """register a function to be called before a component is created.
+    The callback is called with arguments:
+        - component - gradio component that is about to be created.
+        - **kwargs - args to gradio.components.IOComponent.__init__ function
+
+    Use elem_id/label fields of kwargs to figure out which component it is.
+    This can be useful to inject your own components somewhere in the middle of vanilla UI.
+    """
+    add_callback(callback_map['callbacks_before_component'], callback)
+
+
+def on_after_component(callback):
+    """register a function to be called after a component is created. See on_before_component for more."""
+    add_callback(callback_map['callbacks_after_component'], callback)
+
+
+def on_image_grid(callback):
+    """register a function to be called before making an image grid.
+    The callback is called with one argument:
+       - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
+    """
+    add_callback(callback_map['callbacks_image_grid'], callback)
diff --git a/modules/script_loading.py b/modules/script_loading.py
new file mode 100644
index 00000000..f93f0951
--- /dev/null
+++ b/modules/script_loading.py
@@ -0,0 +1,34 @@
+import os
+import sys
+import traceback
+from types import ModuleType
+
+
+def load_module(path):
+    with open(path, "r", encoding="utf8") as file:
+        text = file.read()
+
+    compiled = compile(text, path, 'exec')
+    module = ModuleType(os.path.basename(path))
+    exec(compiled, module.__dict__)
+
+    return module
+
+
+def preload_extensions(extensions_dir, parser):
+    if not os.path.isdir(extensions_dir):
+        return
+
+    for dirname in sorted(os.listdir(extensions_dir)):
+        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
+        if not os.path.isfile(preload_script):
+            continue
+
+        try:
+            module = load_module(preload_script)
+            if hasattr(module, 'preload'):
+                module.preload(parser)
+
+        except Exception:
+            print(f"Error running preload() for {preload_script}", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/scripts.py b/modules/scripts.py
index 1039fa9c..722f8685 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,86 +1,211 @@
 import os
 import sys
 import traceback
+from collections import namedtuple
 
-import modules.ui as ui
 import gradio as gr
 
 from modules.processing import StableDiffusionProcessing
-from modules import shared
+from modules import shared, paths, script_callbacks, extensions, script_loading
+
+AlwaysVisible = object()
+
 
 class Script:
     filename = None
     args_from = None
     args_to = None
+    alwayson = False
+
+    is_txt2img = False
+    is_img2img = False
+
+    """A gr.Group component that has all script's UI inside it"""
+    group = None
+
+    infotext_fields = None
+    """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
+    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
+    """
 
-    # The title of the script. This is what will be displayed in the dropdown menu.
     def title(self):
+        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
+
         raise NotImplementedError()
 
-    # How the script is displayed in the UI. See https://gradio.app/docs/#components
-    # for the different UI components you can use and how to create them.
-    # Most UI components can return a value, such as a boolean for a checkbox.
-    # The returned values are passed to the run method as parameters.
     def ui(self, is_img2img):
+        """this function should create gradio UI elements. See https://gradio.app/docs/#components
+        The return value should be an array of all components that are used in processing.
+        Values of those returned components will be passed to run() and process() functions.
+        """
+
         pass
 
-    # Determines when the script should be shown in the dropdown menu via the 
-    # returned value. As an example:
-    # is_img2img is True if the current tab is img2img, and False if it is txt2img.
-    # Thus, return is_img2img to only show the script on the img2img tab.
     def show(self, is_img2img):
+        """
+        is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
+
+        This function should return:
+         - False if the script should not be shown in UI at all
+         - True if the script should be shown in UI if it's selected in the scripts dropdown
+         - script.AlwaysVisible if the script should be shown in UI at all times
+         """
+
         return True
 
-    # This is where the additional processing is implemented. The parameters include
-    # self, the model object "p" (a StableDiffusionProcessing class, see
-    # processing.py), and the parameters returned by the ui method.
-    # Custom functions can be defined here, and additional libraries can be imported 
-    # to be used in processing. The return value should be a Processed object, which is
-    # what is returned by the process_images method.
-    def run(self, *args):
+    def run(self, p, *args):
+        """
+        This function is called if the script has been selected in the script dropdown.
+        It must do all processing and return the Processed object with results, same as
+        one returned by processing.process_images.
+
+        Usually the processing is done by calling the processing.process_images function.
+
+        args contains all values returned by components from ui()
+        """
+
         raise NotImplementedError()
 
-    # The description method is currently unused.
-    # To add a description that appears when hovering over the title, amend the "titles" 
-    # dict in script.js to include the script title (returned by title) as a key, and 
-    # your description as the value.
+    def process(self, p, *args):
+        """
+        This function is called before processing begins for AlwaysVisible scripts.
+        You can modify the processing object (p) here, inject hooks, etc.
+        args contains all values returned by components from ui()
+        """
+
+        pass
+
+    def process_batch(self, p, *args, **kwargs):
+        """
+        Same as process(), but called for every batch.
+
+        **kwargs will have those items:
+          - batch_number - index of current batch, from 0 to number of batches-1
+          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+          - seeds - list of seeds for current batch
+          - subseeds - list of subseeds for current batch
+        """
+
+        pass
+
+    def postprocess_batch(self, p, *args, **kwargs):
+        """
+        Same as process_batch(), but called for every batch after it has been generated.
+
+        **kwargs will have same items as process_batch, and also:
+          - batch_number - index of current batch, from 0 to number of batches-1
+          - images - torch tensor with all generated images, with values ranging from 0 to 1;
+        """
+
+        pass
+
+    def postprocess(self, p, processed, *args):
+        """
+        This function is called after processing ends for AlwaysVisible scripts.
+        args contains all values returned by components from ui()
+        """
+
+        pass
+
+    def before_component(self, component, **kwargs):
+        """
+        Called before a component is created.
+        Use elem_id/label fields of kwargs to figure out which component it is.
+        This can be useful to inject your own components somewhere in the middle of vanilla UI.
+        You can return created components in the ui() function to add them to the list of arguments for your processing functions
+        """
+
+        pass
+
+    def after_component(self, component, **kwargs):
+        """
+        Called after a component is created. Same as above.
+        """
+
+        pass
+
     def describe(self):
+        """unused"""
         return ""
 
 
+current_basedir = paths.script_path
+
+
+def basedir():
+    """returns the base directory for the current script. For scripts in the main scripts directory,
+    this is the main directory (where webui.py resides), and for scripts in extensions directory
+    (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
+    """
+    return current_basedir
+
+
 scripts_data = []
+ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
+ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
 
 
-def load_scripts(basedir):
-    if not os.path.exists(basedir):
-        return
+def list_scripts(scriptdirname, extension):
+    scripts_list = []
 
-    for filename in sorted(os.listdir(basedir)):
-        path = os.path.join(basedir, filename)
+    basedir = os.path.join(paths.script_path, scriptdirname)
+    if os.path.exists(basedir):
+        for filename in sorted(os.listdir(basedir)):
+            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
 
-        if os.path.splitext(path)[1].lower() != '.py':
+    for ext in extensions.active():
+        scripts_list += ext.list_files(scriptdirname, extension)
+
+    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+
+    return scripts_list
+
+
+def list_files_with_name(filename):
+    res = []
+
+    dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
+
+    for dirpath in dirs:
+        if not os.path.isdir(dirpath):
             continue
 
-        if not os.path.isfile(path):
-            continue
+        path = os.path.join(dirpath, filename)
+        if os.path.isfile(path):
+            res.append(path)
 
+    return res
+
+
+def load_scripts():
+    global current_basedir
+    scripts_data.clear()
+    script_callbacks.clear_callbacks()
+
+    scripts_list = list_scripts("scripts", ".py")
+
+    syspath = sys.path
+
+    for scriptfile in sorted(scripts_list):
         try:
-            with open(path, "r", encoding="utf8") as file:
-                text = file.read()
+            if scriptfile.basedir != paths.script_path:
+                sys.path = [scriptfile.basedir] + sys.path
+            current_basedir = scriptfile.basedir
 
-            from types import ModuleType
-            compiled = compile(text, path, 'exec')
-            module = ModuleType(filename)
-            exec(compiled, module.__dict__)
+            module = script_loading.load_module(scriptfile.path)
 
             for key, script_class in module.__dict__.items():
                 if type(script_class) == type and issubclass(script_class, Script):
-                    scripts_data.append((script_class, path))
+                    scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
 
         except Exception:
-            print(f"Error loading script: {filename}", file=sys.stderr)
+            print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)
 
+        finally:
+            sys.path = syspath
+            current_basedir = paths.script_path
+
 
 def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
     try:
@@ -96,64 +221,94 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
 class ScriptRunner:
     def __init__(self):
         self.scripts = []
+        self.selectable_scripts = []
+        self.alwayson_scripts = []
         self.titles = []
+        self.infotext_fields = []
 
-    def setup_ui(self, is_img2img):
-        for script_class, path in scripts_data:
+    def initialize_scripts(self, is_img2img):
+        self.scripts.clear()
+        self.alwayson_scripts.clear()
+        self.selectable_scripts.clear()
+
+        for script_class, path, basedir in scripts_data:
             script = script_class()
             script.filename = path
+            script.is_txt2img = not is_img2img
+            script.is_img2img = is_img2img
 
-            if not script.show(is_img2img):
-                continue
+            visibility = script.show(script.is_img2img)
 
-            self.scripts.append(script)
+            if visibility == AlwaysVisible:
+                self.scripts.append(script)
+                self.alwayson_scripts.append(script)
+                script.alwayson = True
 
-        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
+            elif visibility:
+                self.scripts.append(script)
+                self.selectable_scripts.append(script)
 
-        dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
-        dropdown.save_to_config = True
-        inputs = [dropdown]
+    def setup_ui(self):
+        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
 
-        for script in self.scripts:
+        inputs = [None]
+        inputs_alwayson = [True]
+
+        def create_script_ui(script, inputs, inputs_alwayson):
             script.args_from = len(inputs)
             script.args_to = len(inputs)
 
-            controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
+            controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
 
             if controls is None:
-                continue
+                return
 
             for control in controls:
                 control.custom_script_source = os.path.basename(script.filename)
-                control.visible = False
+
+            if script.infotext_fields is not None:
+                self.infotext_fields += script.infotext_fields
 
             inputs += controls
+            inputs_alwayson += [script.alwayson for _ in controls]
             script.args_to = len(inputs)
 
-        def select_script(script_index):
-            if 0 < script_index <= len(self.scripts):
-                script = self.scripts[script_index-1]
-                args_from = script.args_from
-                args_to = script.args_to
-            else:
-                args_from = 0
-                args_to = 0
+        for script in self.alwayson_scripts:
+            with gr.Group() as group:
+                create_script_ui(script, inputs, inputs_alwayson)
 
-            return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
+            script.group = group
+
+        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
+        dropdown.save_to_config = True
+        inputs[0] = dropdown
+
+        for script in self.selectable_scripts:
+            with gr.Group(visible=False) as group:
+                create_script_ui(script, inputs, inputs_alwayson)
+
+            script.group = group
+
+        def select_script(script_index):
+            selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
+
+            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
 
         def init_field(title):
+            """called when an initial value is set from ui-config.json to show script's UI components"""
+
             if title == 'None':
                 return
+
             script_index = self.titles.index(title)
-            script = self.scripts[script_index]
-            for i in range(script.args_from, script.args_to):
-                inputs[i].visible = True
+            self.selectable_scripts[script_index].group.visible = True
 
         dropdown.init_field = init_field
+
         dropdown.change(
             fn=select_script,
             inputs=[dropdown],
-            outputs=inputs
+            outputs=[script.group for script in self.selectable_scripts]
         )
 
         return inputs
@@ -164,7 +319,7 @@ class ScriptRunner:
         if script_index == 0:
             return None
 
-        script = self.scripts[script_index-1]
+        script = self.selectable_scripts[script_index-1]
 
         if script is None:
             return None
@@ -176,40 +331,112 @@ class ScriptRunner:
 
         return processed
 
-    def reload_sources(self):
+    def process(self, p):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.process(p, *script_args)
+            except Exception:
+                print(f"Error running process: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
+    def process_batch(self, p, **kwargs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.process_batch(p, *script_args, **kwargs)
+            except Exception:
+                print(f"Error running process_batch: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
+    def postprocess(self, p, processed):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.postprocess(p, processed, *script_args)
+            except Exception:
+                print(f"Error running postprocess: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
+    def postprocess_batch(self, p, images, **kwargs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.postprocess_batch(p, *script_args, images=images, **kwargs)
+            except Exception:
+                print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
+    def before_component(self, component, **kwargs):
+        for script in self.scripts:
+            try:
+                script.before_component(component, **kwargs)
+            except Exception:
+                print(f"Error running before_component: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
+    def after_component(self, component, **kwargs):
+        for script in self.scripts:
+            try:
+                script.after_component(component, **kwargs)
+            except Exception:
+                print(f"Error running after_component: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
+    def reload_sources(self, cache):
         for si, script in list(enumerate(self.scripts)):
-            with open(script.filename, "r", encoding="utf8") as file:
-                args_from = script.args_from
-                args_to = script.args_to
-                filename = script.filename
-                text = file.read()
+            args_from = script.args_from
+            args_to = script.args_to
+            filename = script.filename
 
-                from types import ModuleType
+            module = cache.get(filename, None)
+            if module is None:
+                module = script_loading.load_module(script.filename)
+                cache[filename] = module
 
-                compiled = compile(text, filename, 'exec')
-                module = ModuleType(script.filename)
-                exec(compiled, module.__dict__)
+            for key, script_class in module.__dict__.items():
+                if type(script_class) == type and issubclass(script_class, Script):
+                    self.scripts[si] = script_class()
+                    self.scripts[si].filename = filename
+                    self.scripts[si].args_from = args_from
+                    self.scripts[si].args_to = args_to
 
-                for key, script_class in module.__dict__.items():
-                    if type(script_class) == type and issubclass(script_class, Script):
-                        self.scripts[si] = script_class()
-                        self.scripts[si].filename = filename
-                        self.scripts[si].args_from = args_from
-                        self.scripts[si].args_to = args_to
 
 scripts_txt2img = ScriptRunner()
 scripts_img2img = ScriptRunner()
+scripts_current: ScriptRunner = None
+
 
 def reload_script_body_only():
-    scripts_txt2img.reload_sources()
-    scripts_img2img.reload_sources()
+    cache = {}
+    scripts_txt2img.reload_sources(cache)
+    scripts_img2img.reload_sources(cache)
 
 
-def reload_scripts(basedir):
+def reload_scripts():
     global scripts_txt2img, scripts_img2img
 
-    scripts_data.clear()
-    load_scripts(basedir)
+    load_scripts()
 
     scripts_txt2img = ScriptRunner()
     scripts_img2img = ScriptRunner()
+
+
+def IOComponent_init(self, *args, **kwargs):
+    if scripts_current is not None:
+        scripts_current.before_component(self, **kwargs)
+
+    script_callbacks.before_component_callback(self, **kwargs)
+
+    res = original_IOComponent_init(self, *args, **kwargs)
+
+    script_callbacks.after_component_callback(self, **kwargs)
+
+    if scripts_current is not None:
+        scripts_current.after_component(self, **kwargs)
+
+    return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+gr.components.IOComponent.__init__ = IOComponent_init
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 984b35c4..fa2cd4bb 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -1,60 +1,81 @@
-import math
-import os
-import sys
-import traceback
 import torch
-import numpy as np
-from torch import einsum
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
-from modules.shared import opts, device, cmd_opts
+from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules.hypernetworks import hypernetwork
+from modules.shared import cmd_opts
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+
 from modules.sd_hijack_optimizations import invokeAI_mps_available
 
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
+import ldm.modules.diffusionmodules.openaimodel
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+import ldm.modules.encoders.modules
 
 attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
 
+# new memory efficient cross attention blocks do not support hypernets and we already
+# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
+ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
+ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
+
+# silence new console spam from SD2
+ldm.modules.attention.print = lambda *args: None
+ldm.modules.diffusionmodules.model.print = lambda *args: None
+
+
 def apply_optimizations():
     undo_optimizations()
 
     ldm.modules.diffusionmodules.model.nonlinearity = silu
+    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+    
+    optimization_method = None
 
     if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+        optimization_method = 'xformers'
     elif cmd_opts.opt_split_attention_v1:
         print("Applying v1 cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+        optimization_method = 'V1'
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
         if not invokeAI_mps_available and shared.device.type == 'mps':
             print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
             print("Applying v1 cross attention optimization.")
             ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+            optimization_method = 'V1'
         else:
             print("Applying cross attention optimization (InvokeAI).")
             ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+            optimization_method = 'InvokeAI'
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
         print("Applying cross attention optimization (Doggettx).")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
+        optimization_method = 'Doggettx'
+
+    return optimization_method
 
 
 def undo_optimizations():
-    from modules.hypernetworks import hypernetwork
-
     ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
 
 
-def get_target_prompt_token_count(token_count):
-    return math.ceil(max(token_count, 1) / 75) * 75
+def fix_checkpoint():
+    ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
+    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
+    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
 
 
 class StableDiffusionModelHijack:
@@ -63,18 +84,31 @@ class StableDiffusionModelHijack:
     layers = None
     circular_enabled = False
     clip = None
+    optimization_method = None
 
     embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
 
     def hijack(self, m):
-        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
 
-        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
-        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+            model_embeddings = m.cond_stage_model.roberta.embeddings
+            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
+            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
+
+        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
+            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
+        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
+            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
+            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
+        self.optimization_method = apply_optimizations()
 
         self.clip = m.cond_stage_model
-
-        apply_optimizations()
+        
+        fix_checkpoint()
 
         def flatten(el):
             flattened = [flatten(children) for children in el.children()]
@@ -86,12 +120,23 @@ class StableDiffusionModelHijack:
         self.layers = flatten(m)
 
     def undo_hijack(self, m):
-        if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
+
+        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+            m.cond_stage_model = m.cond_stage_model.wrapped 
+
+        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
             m.cond_stage_model = m.cond_stage_model.wrapped
 
-        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
-        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
-            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
+                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
+            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
+            m.cond_stage_model = m.cond_stage_model.wrapped
+
+        self.apply_circular(False)
+        self.layers = None
+        self.clip = None
 
     def apply_circular(self, enable):
         if self.circular_enabled == enable:
@@ -107,263 +152,8 @@ class StableDiffusionModelHijack:
 
     def tokenize(self, text):
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
-        return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
 
-
-class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
-    def __init__(self, wrapped, hijack):
-        super().__init__()
-        self.wrapped = wrapped
-        self.hijack: StableDiffusionModelHijack = hijack
-        self.tokenizer = wrapped.tokenizer
-        self.token_mults = {}
-
-        self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
-
-        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
-        for text, ident in tokens_with_parens:
-            mult = 1.0
-            for c in text:
-                if c == '[':
-                    mult /= 1.1
-                if c == ']':
-                    mult *= 1.1
-                if c == '(':
-                    mult *= 1.1
-                if c == ')':
-                    mult /= 1.1
-
-            if mult != 1.0:
-                self.token_mults[ident] = mult
-
-    def tokenize_line(self, line, used_custom_terms, hijack_comments):
-        id_end = self.wrapped.tokenizer.eos_token_id
-
-        if opts.enable_emphasis:
-            parsed = prompt_parser.parse_prompt_attention(line)
-        else:
-            parsed = [[line, 1.0]]
-
-        tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
-
-        fixes = []
-        remade_tokens = []
-        multipliers = []
-        last_comma = -1
-
-        for tokens, (text, weight) in zip(tokenized, parsed):
-            i = 0
-            while i < len(tokens):
-                token = tokens[i]
-
-                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
-                if token == self.comma_token:
-                    last_comma = len(remade_tokens)
-                elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
-                    last_comma += 1
-                    reloc_tokens = remade_tokens[last_comma:]
-                    reloc_mults = multipliers[last_comma:]
-
-                    remade_tokens = remade_tokens[:last_comma]
-                    length = len(remade_tokens)
-                    
-                    rem = int(math.ceil(length / 75)) * 75 - length
-                    remade_tokens += [id_end] * rem + reloc_tokens
-                    multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
-                
-                if embedding is None:
-                    remade_tokens.append(token)
-                    multipliers.append(weight)
-                    i += 1
-                else:
-                    emb_len = int(embedding.vec.shape[0])
-                    iteration = len(remade_tokens) // 75
-                    if (len(remade_tokens) + emb_len) // 75 != iteration:
-                        rem = (75 * (iteration + 1) - len(remade_tokens))
-                        remade_tokens += [id_end] * rem
-                        multipliers += [1.0] * rem
-                        iteration += 1
-                    fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
-                    remade_tokens += [0] * emb_len
-                    multipliers += [weight] * emb_len
-                    used_custom_terms.append((embedding.name, embedding.checksum()))
-                    i += embedding_length_in_tokens
-
-        token_count = len(remade_tokens)
-        prompt_target_length = get_target_prompt_token_count(token_count)
-        tokens_to_add = prompt_target_length - len(remade_tokens)
-
-        remade_tokens = remade_tokens + [id_end] * tokens_to_add
-        multipliers = multipliers + [1.0] * tokens_to_add
-
-        return remade_tokens, fixes, multipliers, token_count
-
-    def process_text(self, texts):
-        used_custom_terms = []
-        remade_batch_tokens = []
-        hijack_comments = []
-        hijack_fixes = []
-        token_count = 0
-
-        cache = {}
-        batch_multipliers = []
-        for line in texts:
-            if line in cache:
-                remade_tokens, fixes, multipliers = cache[line]
-            else:
-                remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
-                token_count = max(current_token_count, token_count)
-
-                cache[line] = (remade_tokens, fixes, multipliers)
-
-            remade_batch_tokens.append(remade_tokens)
-            hijack_fixes.append(fixes)
-            batch_multipliers.append(multipliers)
-
-        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
-
-    def process_text_old(self, text):
-        id_start = self.wrapped.tokenizer.bos_token_id
-        id_end = self.wrapped.tokenizer.eos_token_id
-        maxlen = self.wrapped.max_length  # you get to stay at 77
-        used_custom_terms = []
-        remade_batch_tokens = []
-        overflowing_words = []
-        hijack_comments = []
-        hijack_fixes = []
-        token_count = 0
-
-        cache = {}
-        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
-        batch_multipliers = []
-        for tokens in batch_tokens:
-            tuple_tokens = tuple(tokens)
-
-            if tuple_tokens in cache:
-                remade_tokens, fixes, multipliers = cache[tuple_tokens]
-            else:
-                fixes = []
-                remade_tokens = []
-                multipliers = []
-                mult = 1.0
-
-                i = 0
-                while i < len(tokens):
-                    token = tokens[i]
-
-                    embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
-                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
-                    if mult_change is not None:
-                        mult *= mult_change
-                        i += 1
-                    elif embedding is None:
-                        remade_tokens.append(token)
-                        multipliers.append(mult)
-                        i += 1
-                    else:
-                        emb_len = int(embedding.vec.shape[0])
-                        fixes.append((len(remade_tokens), embedding))
-                        remade_tokens += [0] * emb_len
-                        multipliers += [mult] * emb_len
-                        used_custom_terms.append((embedding.name, embedding.checksum()))
-                        i += embedding_length_in_tokens
-
-                if len(remade_tokens) > maxlen - 2:
-                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
-                    ovf = remade_tokens[maxlen - 2:]
-                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
-                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
-                token_count = len(remade_tokens)
-                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
-                remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
-                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
-            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
-            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
-            remade_batch_tokens.append(remade_tokens)
-            hijack_fixes.append(fixes)
-            batch_multipliers.append(multipliers)
-        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-    
-    def forward(self, text):
-        use_old = opts.use_old_emphasis_implementation
-        if use_old:
-            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
-        else:
-            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
-        self.hijack.comments += hijack_comments
-
-        if len(used_custom_terms) > 0:
-            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-        
-        if use_old:
-            self.hijack.fixes = hijack_fixes
-            return self.process_tokens(remade_batch_tokens, batch_multipliers)
-        
-        z = None
-        i = 0
-        while max(map(len, remade_batch_tokens)) != 0:
-            rem_tokens = [x[75:] for x in remade_batch_tokens]
-            rem_multipliers = [x[75:] for x in batch_multipliers]
-
-            self.hijack.fixes = []
-            for unfiltered in hijack_fixes:
-                fixes = []
-                for fix in unfiltered:
-                    if fix[0] == i:
-                        fixes.append(fix[1])
-                self.hijack.fixes.append(fixes)
-            
-            tokens = []
-            multipliers = []
-            for j in range(len(remade_batch_tokens)):
-                if len(remade_batch_tokens[j]) > 0:
-                    tokens.append(remade_batch_tokens[j][:75])
-                    multipliers.append(batch_multipliers[j][:75])
-                else:
-                    tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
-                    multipliers.append([1.0] * 75)
-
-            z1 = self.process_tokens(tokens, multipliers)
-            z = z1 if z is None else torch.cat((z, z1), axis=-2)
-            
-            remade_batch_tokens = rem_tokens
-            batch_multipliers = rem_multipliers
-            i += 1
-        
-        return z
-        
-    
-    def process_tokens(self, remade_batch_tokens, batch_multipliers):
-        if not opts.use_old_emphasis_implementation:
-            remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
-            batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
-        
-        tokens = torch.asarray(remade_batch_tokens).to(device)
-        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
-
-        if opts.CLIP_stop_at_last_layers > 1:
-            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
-            z = self.wrapped.transformer.text_model.final_layer_norm(z)
-        else:
-            z = outputs.last_hidden_state
-
-        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
-        batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
-        batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
-        original_mean = z.mean()
-        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
-        new_mean = z.mean()
-        z *= original_mean / new_mean
-
-        return z
+        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
 
 
 class EmbeddingsWithFixes(torch.nn.Module):
@@ -385,8 +175,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
         for fixes, tensor in zip(batch_fixes, inputs_embeds):
             for offset, embedding in fixes:
                 emb = embedding.vec
-                emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
-                tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
 
             vecs.append(tensor)
 
@@ -403,3 +193,19 @@ def add_circular_option_to_conv_2d():
 
 
 model_hijack = StableDiffusionModelHijack()
+
+
+def register_buffer(self, name, attr):
+    """
+    Fix register buffer bug for Mac OS.
+    """
+
+    if type(attr) == torch.Tensor:
+        if attr.device != devices.device:
+            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
+
+    setattr(self, name, attr)
+
+
+ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
+ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py
new file mode 100644
index 00000000..5712972f
--- /dev/null
+++ b/modules/sd_hijack_checkpoint.py
@@ -0,0 +1,10 @@
+from torch.utils.checkpoint import checkpoint
+
+def BasicTransformerBlock_forward(self, x, context=None):
+    return checkpoint(self._forward, x, context)
+
+def AttentionBlock_forward(self, x):
+    return checkpoint(self._forward, x)
+
+def ResBlock_forward(self, x, emb):
+    return checkpoint(self._forward, x, emb)
\ No newline at end of file
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
new file mode 100644
index 00000000..ca92b142
--- /dev/null
+++ b/modules/sd_hijack_clip.py
@@ -0,0 +1,303 @@
+import math
+
+import torch
+
+from modules import prompt_parser, devices
+from modules.shared import opts
+
+def get_target_prompt_token_count(token_count):
+    return math.ceil(max(token_count, 1) / 75) * 75
+
+
+class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
+    def __init__(self, wrapped, hijack):
+        super().__init__()
+        self.wrapped = wrapped
+        self.hijack = hijack
+
+    def tokenize(self, texts):
+        raise NotImplementedError
+
+    def encode_with_transformers(self, tokens):
+        raise NotImplementedError
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        raise NotImplementedError
+
+    def tokenize_line(self, line, used_custom_terms, hijack_comments):
+        if opts.enable_emphasis:
+            parsed = prompt_parser.parse_prompt_attention(line)
+        else:
+            parsed = [[line, 1.0]]
+
+        tokenized = self.tokenize([text for text, _ in parsed])
+
+        fixes = []
+        remade_tokens = []
+        multipliers = []
+        last_comma = -1
+
+        for tokens, (text, weight) in zip(tokenized, parsed):
+            i = 0
+            while i < len(tokens):
+                token = tokens[i]
+
+                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+                if token == self.comma_token:
+                    last_comma = len(remade_tokens)
+                elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
+                    last_comma += 1
+                    reloc_tokens = remade_tokens[last_comma:]
+                    reloc_mults = multipliers[last_comma:]
+
+                    remade_tokens = remade_tokens[:last_comma]
+                    length = len(remade_tokens)
+
+                    rem = int(math.ceil(length / 75)) * 75 - length
+                    remade_tokens += [self.id_end] * rem + reloc_tokens
+                    multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+
+                if embedding is None:
+                    remade_tokens.append(token)
+                    multipliers.append(weight)
+                    i += 1
+                else:
+                    emb_len = int(embedding.vec.shape[0])
+                    iteration = len(remade_tokens) // 75
+                    if (len(remade_tokens) + emb_len) // 75 != iteration:
+                        rem = (75 * (iteration + 1) - len(remade_tokens))
+                        remade_tokens += [self.id_end] * rem
+                        multipliers += [1.0] * rem
+                        iteration += 1
+                    fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
+                    remade_tokens += [0] * emb_len
+                    multipliers += [weight] * emb_len
+                    used_custom_terms.append((embedding.name, embedding.checksum()))
+                    i += embedding_length_in_tokens
+
+        token_count = len(remade_tokens)
+        prompt_target_length = get_target_prompt_token_count(token_count)
+        tokens_to_add = prompt_target_length - len(remade_tokens)
+
+        remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
+        multipliers = multipliers + [1.0] * tokens_to_add
+
+        return remade_tokens, fixes, multipliers, token_count
+
+    def process_text(self, texts):
+        used_custom_terms = []
+        remade_batch_tokens = []
+        hijack_comments = []
+        hijack_fixes = []
+        token_count = 0
+
+        cache = {}
+        batch_multipliers = []
+        for line in texts:
+            if line in cache:
+                remade_tokens, fixes, multipliers = cache[line]
+            else:
+                remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+                token_count = max(current_token_count, token_count)
+
+                cache[line] = (remade_tokens, fixes, multipliers)
+
+            remade_batch_tokens.append(remade_tokens)
+            hijack_fixes.append(fixes)
+            batch_multipliers.append(multipliers)
+
+        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+    def process_text_old(self, texts):
+        id_start = self.id_start
+        id_end = self.id_end
+        maxlen = self.wrapped.max_length  # you get to stay at 77
+        used_custom_terms = []
+        remade_batch_tokens = []
+        hijack_comments = []
+        hijack_fixes = []
+        token_count = 0
+
+        cache = {}
+        batch_tokens = self.tokenize(texts)
+        batch_multipliers = []
+        for tokens in batch_tokens:
+            tuple_tokens = tuple(tokens)
+
+            if tuple_tokens in cache:
+                remade_tokens, fixes, multipliers = cache[tuple_tokens]
+            else:
+                fixes = []
+                remade_tokens = []
+                multipliers = []
+                mult = 1.0
+
+                i = 0
+                while i < len(tokens):
+                    token = tokens[i]
+
+                    embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
+                    if mult_change is not None:
+                        mult *= mult_change
+                        i += 1
+                    elif embedding is None:
+                        remade_tokens.append(token)
+                        multipliers.append(mult)
+                        i += 1
+                    else:
+                        emb_len = int(embedding.vec.shape[0])
+                        fixes.append((len(remade_tokens), embedding))
+                        remade_tokens += [0] * emb_len
+                        multipliers += [mult] * emb_len
+                        used_custom_terms.append((embedding.name, embedding.checksum()))
+                        i += embedding_length_in_tokens
+
+                if len(remade_tokens) > maxlen - 2:
+                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+                    ovf = remade_tokens[maxlen - 2:]
+                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
+                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+                token_count = len(remade_tokens)
+                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+                remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+            remade_batch_tokens.append(remade_tokens)
+            hijack_fixes.append(fixes)
+            batch_multipliers.append(multipliers)
+        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+    def forward(self, text):
+        use_old = opts.use_old_emphasis_implementation
+        if use_old:
+            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
+        else:
+            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+
+        self.hijack.comments += hijack_comments
+
+        if len(used_custom_terms) > 0:
+            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+        if use_old:
+            self.hijack.fixes = hijack_fixes
+            return self.process_tokens(remade_batch_tokens, batch_multipliers)
+
+        z = None
+        i = 0
+        while max(map(len, remade_batch_tokens)) != 0:
+            rem_tokens = [x[75:] for x in remade_batch_tokens]
+            rem_multipliers = [x[75:] for x in batch_multipliers]
+
+            self.hijack.fixes = []
+            for unfiltered in hijack_fixes:
+                fixes = []
+                for fix in unfiltered:
+                    if fix[0] == i:
+                        fixes.append(fix[1])
+                self.hijack.fixes.append(fixes)
+
+            tokens = []
+            multipliers = []
+            for j in range(len(remade_batch_tokens)):
+                if len(remade_batch_tokens[j]) > 0:
+                    tokens.append(remade_batch_tokens[j][:75])
+                    multipliers.append(batch_multipliers[j][:75])
+                else:
+                    tokens.append([self.id_end] * 75)
+                    multipliers.append([1.0] * 75)
+
+            z1 = self.process_tokens(tokens, multipliers)
+            z = z1 if z is None else torch.cat((z, z1), axis=-2)
+
+            remade_batch_tokens = rem_tokens
+            batch_multipliers = rem_multipliers
+            i += 1
+
+        return z
+
+    def process_tokens(self, remade_batch_tokens, batch_multipliers):
+        if not opts.use_old_emphasis_implementation:
+            remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
+            batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+
+        tokens = torch.asarray(remade_batch_tokens).to(devices.device)
+
+        if self.id_end != self.id_pad:
+            for batch_pos in range(len(remade_batch_tokens)):
+                index = remade_batch_tokens[batch_pos].index(self.id_end)
+                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
+
+        z = self.encode_with_transformers(tokens)
+
+        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+        batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
+        batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
+        original_mean = z.mean()
+        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+        new_mean = z.mean()
+        z *= original_mean / new_mean
+
+        return z
+
+
+class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+        self.tokenizer = wrapped.tokenizer
+
+        vocab = self.tokenizer.get_vocab()
+
+        self.comma_token = vocab.get(',</w>', None)
+
+        self.token_mults = {}
+        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
+        for text, ident in tokens_with_parens:
+            mult = 1.0
+            for c in text:
+                if c == '[':
+                    mult /= 1.1
+                if c == ']':
+                    mult *= 1.1
+                if c == '(':
+                    mult *= 1.1
+                if c == ')':
+                    mult /= 1.1
+
+            if mult != 1.0:
+                self.token_mults[ident] = mult
+
+        self.id_start = self.wrapped.tokenizer.bos_token_id
+        self.id_end = self.wrapped.tokenizer.eos_token_id
+        self.id_pad = self.id_end
+
+    def tokenize(self, texts):
+        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
+
+        return tokenized
+
+    def encode_with_transformers(self, tokens):
+        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+
+        if opts.CLIP_stop_at_last_layers > 1:
+            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
+            z = self.wrapped.transformer.text_model.final_layer_norm(z)
+        else:
+            z = outputs.last_hidden_state
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        embedding_layer = self.wrapped.transformer.text_model.embeddings
+        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
+
+        return embedded
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
new file mode 100644
index 00000000..31d2c898
--- /dev/null
+++ b/modules/sd_hijack_inpainting.py
@@ -0,0 +1,111 @@
+import os
+import torch
+
+from einops import repeat
+from omegaconf import ListConfig
+
+import ldm.models.diffusion.ddpm
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+
+
+@torch.no_grad()
+def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
+    b, *_, device = *x.shape, x.device
+
+    def get_model_output(x, t):
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            e_t = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [
+                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
+                            for i in range(len(c[k]))
+                        ]
+                    else:
+                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
+
+            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps"
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        return e_t
+
+    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+    def get_x_prev_and_pred_x0(e_t, index):
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+        # current prediction for x_0
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+        if dynamic_threshold is not None:
+            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    e_t = get_model_output(x, t)
+    if len(old_eps) == 0:
+        # Pseudo Improved Euler (2nd order)
+        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+        e_t_next = get_model_output(x_prev, t_next)
+        e_t_prime = (e_t + e_t_next) / 2
+    elif len(old_eps) == 1:
+        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+        e_t_prime = (3 * e_t - old_eps[-1]) / 2
+    elif len(old_eps) == 2:
+        # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+    elif len(old_eps) >= 3:
+        # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+    return x_prev, pred_x0, e_t
+
+
+def should_hijack_inpainting(checkpoint_info):
+    from modules import sd_models
+
+    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
+    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
+
+    return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
+
+
+def do_inpainting_hijack():
+    # p_sample_plms is needed because PLMS can't work with dicts as conditionings
+
+    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py
new file mode 100644
index 00000000..f733e852
--- /dev/null
+++ b/modules/sd_hijack_open_clip.py
@@ -0,0 +1,37 @@
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+tokenizer = open_clip.tokenizer._tokenizer
+
+
+class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+        self.id_start = tokenizer.encoder["<start_of_text>"]
+        self.id_end = tokenizer.encoder["<end_of_text>"]
+        self.id_pad = 0
+
+    def tokenize(self, texts):
+        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+
+        tokenized = [tokenizer.encode(text) for text in texts]
+
+        return tokenized
+
+    def encode_with_transformers(self, tokens):
+        # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
+        z = self.wrapped.encode_with_transformer(tokens)
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        ids = tokenizer.encode(init_text)
+        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
+
+        return embedded
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 98123fbf..02c87f40 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -127,7 +127,7 @@ def check_for_psutil():
 
 invokeAI_mps_available = check_for_psutil()
 
-# -- Taken from https://github.com/invoke-ai/InvokeAI --
+# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
 if invokeAI_mps_available:
     import psutil
     mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
     return r
 
 def einsum_op_mps_v1(q, k, v):
-    if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+    if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
         return einsum_op_compvis(q, k, v)
     else:
         slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+        if slice_size % 4096 == 0:
+            slice_size -= 1
         return einsum_op_slice_1(q, k, v, slice_size)
 
 def einsum_op_mps_v2(q, k, v):
-    if mem_total_gb > 8 and q.shape[1] <= 4096:
+    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
         return einsum_op_compvis(q, k, v)
     else:
         return einsum_op_slice_0(q, k, v, 1)
@@ -188,7 +190,7 @@ def einsum_op(q, k, v):
         return einsum_op_cuda(q, k, v)
 
     if q.device.type == 'mps':
-        if mem_total_gb >= 32:
+        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
             return einsum_op_mps_v1(q, k, v)
         return einsum_op_mps_v2(q, k, v)
 
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
new file mode 100644
index 00000000..18daf8c1
--- /dev/null
+++ b/modules/sd_hijack_unet.py
@@ -0,0 +1,30 @@
+import torch
+
+
+class TorchHijackForUnet:
+    """
+    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
+    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
+    """
+
+    def __getattr__(self, item):
+        if item == 'cat':
+            return self.cat
+
+        if hasattr(torch, item):
+            return getattr(torch, item)
+
+        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+    def cat(self, tensors, *args, **kwargs):
+        if len(tensors) == 2:
+            a, b = tensors
+            if a.shape[-2:] != b.shape[-2:]:
+                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
+
+            tensors = (a, b)
+
+        return torch.cat(tensors, *args, **kwargs)
+
+
+th = TorchHijackForUnet()
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
new file mode 100644
index 00000000..4ac51c38
--- /dev/null
+++ b/modules/sd_hijack_xlmr.py
@@ -0,0 +1,34 @@
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+
+class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+        self.id_start = wrapped.config.bos_token_id
+        self.id_end = wrapped.config.eos_token_id
+        self.id_pad = wrapped.config.pad_token_id
+
+        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma
+
+    def encode_with_transformers(self, tokens):
+        # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
+        # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
+        # layer to work with - you have to use the last
+
+        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
+        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
+        z = features['projection_state']
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        embedding_layer = self.wrapped.roberta.embeddings
+        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+
+        return embedded
diff --git a/modules/sd_models.py b/modules/sd_models.py
index eae22e87..76a89e88 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,26 +1,33 @@
 import collections
 import os.path
 import sys
+import gc
 from collections import namedtuple
 import torch
+import re
+import safetensors.torch
 from omegaconf import OmegaConf
+from os import mkdir
+from urllib import request
+import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices
+from modules import shared, modelloader, devices, script_callbacks, sd_vae
 from modules.paths import models_path
+from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
 
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
 checkpoints_list = {}
 checkpoints_loaded = collections.OrderedDict()
 
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
 
-    from transformers import logging
+    from transformers import logging, CLIPModel
 
     logging.set_verbosity_error()
 except Exception:
@@ -32,15 +39,26 @@ def setup_model():
         os.makedirs(model_path)
 
     list_models()
+    enable_midas_autodownload()
 
 
-def checkpoint_tiles():
-    return sorted([x.title for x in checkpoints_list.values()])
+def checkpoint_tiles(): 
+    convert = lambda name: int(name) if name.isdigit() else name.lower() 
+    alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] 
+    return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+
+
+def find_checkpoint_config(info):
+    config = os.path.splitext(info.filename)[0] + ".yaml"
+    if os.path.exists(config):
+        return config
+
+    return shared.cmd_opts.config
 
 
 def list_models():
     checkpoints_list.clear()
-    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
+    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
 
     def modeltitle(path, shorthash):
         abspath = os.path.abspath(path)
@@ -63,7 +81,7 @@ def list_models():
     if os.path.exists(cmd_ckpt):
         h = model_hash(cmd_ckpt)
         title, short_model_name = modeltitle(cmd_ckpt, h)
-        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
+        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
         shared.opts.data['sd_model_checkpoint'] = title
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@@ -71,12 +89,7 @@ def list_models():
         h = model_hash(filename)
         title, short_model_name = modeltitle(filename, h)
 
-        basename, _ = os.path.splitext(filename)
-        config = basename + ".yaml"
-        if not os.path.exists(config):
-            config = shared.cmd_opts.config
-
-        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
 
 
 def get_closet_checkpoint_match(searchString):
@@ -101,18 +114,19 @@ def model_hash(filename):
 
 def select_checkpoint():
     model_checkpoint = shared.opts.sd_model_checkpoint
+        
     checkpoint_info = checkpoints_list.get(model_checkpoint, None)
     if checkpoint_info is not None:
         return checkpoint_info
 
     if len(checkpoints_list) == 0:
-        print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+        print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
         if shared.cmd_opts.ckpt is not None:
             print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
         print(f" - directory {model_path}", file=sys.stderr)
         if shared.cmd_opts.ckpt_dir is not None:
             print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
-        print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
+        print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
         exit(1)
 
     checkpoint_info = next(iter(checkpoints_list.values()))
@@ -138,8 +152,8 @@ def transform_checkpoint_dict_key(k):
 
 
 def get_state_dict_from_checkpoint(pl_sd):
-    if "state_dict" in pl_sd:
-        pl_sd = pl_sd["state_dict"]
+    pl_sd = pl_sd.pop("state_dict", pl_sd)
+    pl_sd.pop("state_dict", None)
 
     sd = {}
     for k, v in pl_sd.items():
@@ -154,64 +168,156 @@ def get_state_dict_from_checkpoint(pl_sd):
     return pl_sd
 
 
-def load_model_weights(model, checkpoint_info):
+def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
+    _, extension = os.path.splitext(checkpoint_file)
+    if extension.lower() == ".safetensors":
+        device = map_location or shared.weight_load_location
+        if device is None:
+            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+        pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+    else:
+        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
+
+    if print_global_state and "global_step" in pl_sd:
+        print(f"Global Step: {pl_sd['global_step']}")
+
+    sd = get_state_dict_from_checkpoint(pl_sd)
+    return sd
+
+
+def load_model_weights(model, checkpoint_info, vae_file="auto"):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
 
-    if checkpoint_info not in checkpoints_loaded:
+    cache_enabled = shared.opts.sd_checkpoint_cache > 0
+
+    if cache_enabled and checkpoint_info in checkpoints_loaded:
+        # use checkpoint cache
+        print(f"Loading weights [{sd_model_hash}] from cache")
+        model.load_state_dict(checkpoints_loaded[checkpoint_info])
+    else:
+        # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
-        pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
-        if "global_step" in pl_sd:
-            print(f"Global Step: {pl_sd['global_step']}")
-
-        sd = get_state_dict_from_checkpoint(pl_sd)
-        missing, extra = model.load_state_dict(sd, strict=False)
+        sd = read_state_dict(checkpoint_file)
+        model.load_state_dict(sd, strict=False)
+        del sd
+        
+        if cache_enabled:
+            # cache newly loaded model
+            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
 
         if shared.cmd_opts.opt_channelslast:
             model.to(memory_format=torch.channels_last)
 
         if not shared.cmd_opts.no_half:
+            vae = model.first_stage_model
+
+            # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+            if shared.cmd_opts.no_half_vae:
+                model.first_stage_model = None
+
             model.half()
+            model.first_stage_model = vae
 
         devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
         devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
 
-        vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
-
-        if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
-            vae_file = shared.cmd_opts.vae_path
-
-        if os.path.exists(vae_file):
-            print(f"Loading VAE weights from: {vae_file}")
-            vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
-            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
-            model.first_stage_model.load_state_dict(vae_dict)
-
         model.first_stage_model.to(devices.dtype_vae)
 
-        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+    # clean up cache if limit is reached
+    if cache_enabled:
+        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
             checkpoints_loaded.popitem(last=False)  # LRU
-    else:
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        checkpoints_loaded.move_to_end(checkpoint_info)
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
 
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_file
     model.sd_checkpoint_info = checkpoint_info
 
+    model.logvar = model.logvar.to(devices.device)  # fix for training
 
-def load_model():
+    sd_vae.delete_base_vae()
+    sd_vae.clear_loaded_vae()
+    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+    sd_vae.load_vae(model, vae_file)
+
+
+def enable_midas_autodownload():
+    """
+    Gives the ldm.modules.midas.api.load_model function automatic downloading.
+
+    When the 512-depth-ema model, and other future models like it, is loaded,
+    it calls midas.api.load_model to load the associated midas depth model.
+    This function applies a wrapper to download the model to the correct
+    location automatically.
+    """
+
+    midas_path = os.path.join(models_path, 'midas')
+
+    # stable-diffusion-stability-ai hard-codes the midas model path to
+    # a location that differs from where other scripts using this model look.
+    # HACK: Overriding the path here.
+    for k, v in midas.api.ISL_PATHS.items():
+        file_name = os.path.basename(v)
+        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
+
+    midas_urls = {
+        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
+        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
+        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
+        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
+    }
+
+    midas.api.load_model_inner = midas.api.load_model
+
+    def load_model_wrapper(model_type):
+        path = midas.api.ISL_PATHS[model_type]
+        if not os.path.exists(path):
+            if not os.path.exists(midas_path):
+                mkdir(midas_path)
+    
+            print(f"Downloading midas model weights for {model_type} to {path}")
+            request.urlretrieve(midas_urls[model_type], path)
+            print(f"{model_type} downloaded")
+
+        return midas.api.load_model_inner(model_type)
+
+    midas.api.load_model = load_model_wrapper
+
+
+def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
-    checkpoint_info = select_checkpoint()
+    checkpoint_info = checkpoint_info or select_checkpoint()
+    checkpoint_config = find_checkpoint_config(checkpoint_info)
 
-    if checkpoint_info.config != shared.cmd_opts.config:
-        print(f"Loading config from: {checkpoint_info.config}")
+    if checkpoint_config != shared.cmd_opts.config:
+        print(f"Loading config from: {checkpoint_config}")
+
+    if shared.sd_model:
+        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+        shared.sd_model = None
+        gc.collect()
+        devices.torch_gc()
+
+    sd_config = OmegaConf.load(checkpoint_config)
+    
+    if should_hijack_inpainting(checkpoint_info):
+        # Hardcoded config for now...
+        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+        sd_config.model.params.conditioning_key = "hybrid"
+        sd_config.model.params.unet_config.params.in_channels = 9
+        sd_config.model.params.finetune_keys = None
+
+    if not hasattr(sd_config.model.params, "use_ema"):
+        sd_config.model.params.use_ema = False
+
+    do_inpainting_hijack()
+
+    if shared.cmd_opts.no_half:
+        sd_config.model.params.unet_config.params.use_fp16 = False
 
-    sd_config = OmegaConf.load(checkpoint_info.config)
     sd_model = instantiate_from_config(sd_config.model)
+
     load_model_weights(sd_model, checkpoint_info)
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -222,21 +328,34 @@ def load_model():
     sd_hijack.model_hijack.hijack(sd_model)
 
     sd_model.eval()
+    shared.sd_model = sd_model
+
+    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
+
+    script_callbacks.model_loaded_callback(sd_model)
+
+    print("Model loaded.")
 
-    print(f"Model loaded.")
     return sd_model
 
 
-def reload_model_weights(sd_model, info=None):
+def reload_model_weights(sd_model=None, info=None):
     from modules import lowvram, devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
 
+    if not sd_model:
+        sd_model = shared.sd_model
+
+    current_checkpoint_info = sd_model.sd_checkpoint_info
+    checkpoint_config = find_checkpoint_config(current_checkpoint_info)
+
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return
 
-    if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+    if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+        del sd_model
         checkpoints_loaded.clear()
-        shared.sd_model = load_model()
+        load_model(checkpoint_info)
         return shared.sd_model
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -246,12 +365,19 @@ def reload_model_weights(sd_model, info=None):
 
     sd_hijack.model_hijack.undo_hijack(sd_model)
 
-    load_model_weights(sd_model, checkpoint_info)
+    try:
+        load_model_weights(sd_model, checkpoint_info)
+    except Exception as e:
+        print("Failed to load checkpoint, restoring previous")
+        load_model_weights(sd_model, current_checkpoint_info)
+        raise
+    finally:
+        sd_hijack.model_hijack.hijack(sd_model)
+        script_callbacks.model_loaded_callback(sd_model)
 
-    sd_hijack.model_hijack.hijack(sd_model)
+        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+            sd_model.to(devices.device)
 
-    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
-        sd_model.to(devices.device)
+    print("Weights loaded.")
 
-    print(f"Weights loaded.")
     return sd_model
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index b58e810b..e904d860 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,32 +1,41 @@
-from collections import namedtuple
+from collections import namedtuple, deque
 import numpy as np
+from math import floor
 import torch
 import tqdm
 from PIL import Image
 import inspect
 import k_diffusion.sampling
+import torchsde._brownian.brownian_interval
 import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing
+from modules import prompt_parser, devices, processing, images, sd_vae_approx
 
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
 
 
 SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
 
 samplers_k_diffusion = [
-    ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
+    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
     ('Euler', 'sample_euler', ['k_euler'], {}),
     ('LMS', 'sample_lms', ['k_lms'], {}),
     ('Heun', 'sample_heun', ['k_heun'], {}),
-    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
-    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
+    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
+    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
+    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+    ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
+    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
     ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
     ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
     ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
-    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
-    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
+    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+    ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
 ]
 
 samplers_data_k_diffusion = [
@@ -40,16 +49,24 @@ all_samplers = [
     SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
     SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
 ]
+all_samplers_map = {x.name: x for x in all_samplers}
 
 samplers = []
 samplers_for_img2img = []
+samplers_map = {}
 
 
-def create_sampler_with_index(list_of_configs, index, model):
-    config = list_of_configs[index]
+def create_sampler(name, model):
+    if name is not None:
+        config = all_samplers_map.get(name, None)
+    else:
+        config = all_samplers[0]
+
+    assert config is not None, f'bad sampler name: {name}'
+
     sampler = config.constructor(model)
     sampler.config = config
-    
+
     return sampler
 
 
@@ -62,6 +79,12 @@ def set_samplers():
     samplers = [x for x in all_samplers if x.name not in hidden]
     samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
 
+    samplers_map.clear()
+    for sampler in all_samplers:
+        samplers_map[sampler.name.lower()] = sampler.name
+        for alias in sampler.aliases:
+            samplers_map[alias.lower()] = sampler.name
+
 
 set_samplers()
 
@@ -71,6 +94,7 @@ sampler_extra_params = {
     'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
 }
 
+
 def setup_img2img_steps(p, steps=None):
     if opts.img2img_fix_steps or steps is not None:
         steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
@@ -82,14 +106,34 @@ def setup_img2img_steps(p, steps=None):
     return steps, t_enc
 
 
-def sample_to_image(samples):
-    x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+    if approximation is None:
+        approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+    if approximation == 2:
+        x_sample = sd_vae_approx.cheap_approximation(sample)
+    elif approximation == 1:
+        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+    else:
+        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+
     x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
     x_sample = x_sample.astype(np.uint8)
     return Image.fromarray(x_sample)
 
 
+def sample_to_image(samples, index=0, approximation=None):
+    return single_sample_to_image(samples[index], approximation)
+
+
+def samples_to_image_grid(samples, approximation=None):
+    return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+
+
 def store_latent(decoded):
     state.current_latent = decoded
 
@@ -105,7 +149,8 @@ class InterruptedException(BaseException):
 class VanillaStableDiffusionSampler:
     def __init__(self, constructor, sd_model):
         self.sampler = constructor(sd_model)
-        self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
+        self.is_plms = hasattr(self.sampler, 'p_sample_plms')
+        self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
         self.mask = None
         self.nmask = None
         self.init_latent = None
@@ -117,6 +162,8 @@ class VanillaStableDiffusionSampler:
         self.config = None
         self.last_latent = None
 
+        self.conditioning_key = sd_model.model.conditioning_key
+
     def number_of_needed_noises(self, p):
         return 0
 
@@ -136,6 +183,12 @@ class VanillaStableDiffusionSampler:
         if self.stop_at is not None and self.step > self.stop_at:
             raise InterruptedException
 
+        # Have to unwrap the inpainting conditioning here to perform pre-processing
+        image_conditioning = None
+        if isinstance(cond, dict):
+            image_conditioning = cond["c_concat"][0]
+            cond = cond["c_crossattn"][0]
+            unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
 
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
@@ -157,6 +210,12 @@ class VanillaStableDiffusionSampler:
             img_orig = self.sampler.model.q_sample(self.init_latent, ts)
             x_dec = img_orig * self.mask + self.nmask * x_dec
 
+        # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+        # Note that they need to be lists because it just concatenates them later.
+        if image_conditioning is not None:
+            cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
         res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
 
         if self.mask is not None:
@@ -182,39 +241,52 @@ class VanillaStableDiffusionSampler:
         self.mask = p.mask if hasattr(p, 'mask') else None
         self.nmask = p.nmask if hasattr(p, 'nmask') else None
 
-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
-        steps, t_enc = setup_img2img_steps(p, steps)
+    def adjust_steps_if_invalid(self, p, num_steps):
+        if  (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+            valid_step = 999 / (1000 // num_steps)
+            if valid_step == floor(valid_step):
+                return int(valid_step) + 1
+        
+        return num_steps
 
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        steps, t_enc = setup_img2img_steps(p, steps)
+        steps = self.adjust_steps_if_invalid(p, steps)
         self.initialize(p)
 
-        # existing code fails with certain step counts, like 9
-        try:
-            self.sampler.make_schedule(ddim_num_steps=steps,  ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
-        except Exception:
-            self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
-
+        self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
         x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
 
         self.init_latent = x
+        self.last_latent = x
         self.step = 0
 
-        samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+        # Wrap the conditioning models with additional image conditioning for inpainting model
+        if image_conditioning is not None:
+            conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+            
+            
+        samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
 
         return samples
 
-    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         self.initialize(p)
 
         self.init_latent = None
+        self.last_latent = x
         self.step = 0
 
-        steps = steps or p.steps
+        steps = self.adjust_steps_if_invalid(p, steps or p.steps)
 
-        # existing code fails with certain step counts, like 9
-        try:
-            samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
-        except Exception:
-            samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+        # Wrap the conditioning models with additional image conditioning for inpainting model
+        # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
+        if image_conditioning is not None:
+            conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
+            unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
+
+        samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
 
         return samples_ddim
 
@@ -228,7 +300,17 @@ class CFGDenoiser(torch.nn.Module):
         self.init_latent = None
         self.step = 0
 
-    def forward(self, x, sigma, uncond, cond, cond_scale):
+    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+        denoised_uncond = x_out[-uncond.shape[0]:]
+        denoised = torch.clone(denoised_uncond)
+
+        for i, conds in enumerate(conds_list):
+            for cond_index, weight in conds:
+                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+        return denoised
+
+    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
         if state.interrupted or state.skipped:
             raise InterruptedException
 
@@ -239,35 +321,37 @@ class CFGDenoiser(torch.nn.Module):
         repeats = [len(conds_list[i]) for i in range(batch_size)]
 
         x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
         sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
 
+        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
+        cfg_denoiser_callback(denoiser_params)
+        x_in = denoiser_params.x
+        image_cond_in = denoiser_params.image_cond
+        sigma_in = denoiser_params.sigma
+
         if tensor.shape[1] == uncond.shape[1]:
             cond_in = torch.cat([tensor, uncond])
 
             if shared.batch_cond_uncond:
-                x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+                x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
             else:
                 x_out = torch.zeros_like(x_in)
                 for batch_offset in range(0, x_out.shape[0], batch_size):
                     a = batch_offset
                     b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
         else:
             x_out = torch.zeros_like(x_in)
             batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
             for batch_offset in range(0, tensor.shape[0], batch_size):
                 a = batch_offset
                 b = min(a + batch_size, tensor.shape[0])
-                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
+                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
 
-            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
+            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
 
-        denoised_uncond = x_out[-uncond.shape[0]:]
-        denoised = torch.clone(denoised_uncond)
-
-        for i, conds in enumerate(conds_list):
-            for cond_index, weight in conds:
-                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+        denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
 
         if self.mask is not None:
             denoised = self.init_latent * self.mask + self.nmask * denoised
@@ -278,34 +362,63 @@ class CFGDenoiser(torch.nn.Module):
 
 
 class TorchHijack:
-    def __init__(self, kdiff_sampler):
-        self.kdiff_sampler = kdiff_sampler
+    def __init__(self, sampler_noises):
+        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+        # implementation.
+        self.sampler_noises = deque(sampler_noises)
 
     def __getattr__(self, item):
         if item == 'randn_like':
-            return self.kdiff_sampler.randn_like
+            return self.randn_like
 
         if hasattr(torch, item):
             return getattr(torch, item)
 
         raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
 
+    def randn_like(self, x):
+        if self.sampler_noises:
+            noise = self.sampler_noises.popleft()
+            if noise.shape == x.shape:
+                return noise
+
+        if x.device.type == 'mps':
+            return torch.randn_like(x, device=devices.cpu).to(x.device)
+        else:
+            return torch.randn_like(x)
+
+
+# MPS fix for randn in torchsde
+def torchsde_randn(size, dtype, device, seed):
+    if device.type == 'mps':
+        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+    else:
+        generator = torch.Generator(device).manual_seed(int(seed))
+        return torch.randn(size, dtype=dtype, device=device, generator=generator)
+
+
+torchsde._brownian.brownian_interval._randn = torchsde_randn
+
 
 class KDiffusionSampler:
     def __init__(self, funcname, sd_model):
-        self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
+        denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+
+        self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
         self.funcname = funcname
         self.func = getattr(k_diffusion.sampling, self.funcname)
         self.extra_params = sampler_extra_params.get(funcname, [])
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
         self.sampler_noises = None
-        self.sampler_noise_index = 0
         self.stop_at = None
         self.eta = None
         self.default_eta = 1.0
         self.config = None
         self.last_latent = None
 
+        self.conditioning_key = sd_model.model.conditioning_key
+
     def callback_state(self, d):
         step = d['i']
         latent = d["denoised"]
@@ -330,26 +443,13 @@ class KDiffusionSampler:
     def number_of_needed_noises(self, p):
         return p.steps
 
-    def randn_like(self, x):
-        noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
-
-        if noise is not None and x.shape == noise.shape:
-            res = noise
-        else:
-            res = torch.randn_like(x)
-
-        self.sampler_noise_index += 1
-        return res
-
     def initialize(self, p):
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
         self.model_wrap.step = 0
-        self.sampler_noise_index = 0
         self.eta = p.eta or opts.eta_ancestral
 
-        if self.sampler_noises is not None:
-            k_diffusion.sampling.torch = TorchHijack(self)
+        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
 
         extra_params_kwargs = {}
         for param_name in self.extra_params:
@@ -361,16 +461,26 @@ class KDiffusionSampler:
 
         return extra_params_kwargs
 
-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
-        steps, t_enc = setup_img2img_steps(p, steps)
-
+    def get_sigmas(self, p, steps):
         if p.sampler_noise_scheduler_override:
             sigmas = p.sampler_noise_scheduler_override(steps)
         elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
-            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
+            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+
+            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
         else:
             sigmas = self.model_wrap.get_sigmas(steps)
 
+        if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False):
+            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
+        return sigmas
+
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        steps, t_enc = setup_img2img_steps(p, steps)
+
+        sigmas = self.get_sigmas(p, steps)
+
         sigma_sched = sigmas[steps - t_enc - 1:]
         xi = x + noise * sigma_sched[0]
         
@@ -388,20 +498,21 @@ class KDiffusionSampler:
             extra_params_kwargs['sigmas'] = sigma_sched
 
         self.model_wrap_cfg.init_latent = x
+        self.last_latent = x
 
-        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+            'cond': conditioning, 
+            'image_cond': image_conditioning, 
+            'uncond': unconditional_conditioning, 
+            'cond_scale': p.cfg_scale
+        }, disable=False, callback=self.callback_state, **extra_params_kwargs))
 
         return samples
 
-    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
         steps = steps or p.steps
 
-        if p.sampler_noise_scheduler_override:
-            sigmas = p.sampler_noise_scheduler_override(steps)
-        elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
-            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
-        else:
-            sigmas = self.model_wrap.get_sigmas(steps)
+        sigmas = self.get_sigmas(p, steps)
 
         x = x * sigmas[0]
 
@@ -414,7 +525,13 @@ class KDiffusionSampler:
         else:
             extra_params_kwargs['sigmas'] = sigmas
 
-        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+        self.last_latent = x
+        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+            'cond': conditioning, 
+            'image_cond': image_conditioning, 
+            'uncond': unconditional_conditioning, 
+            'cond_scale': p.cfg_scale
+        }, disable=False, callback=self.callback_state, **extra_params_kwargs))
 
         return samples
 
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
new file mode 100644
index 00000000..ac71d62d
--- /dev/null
+++ b/modules/sd_vae.py
@@ -0,0 +1,231 @@
+import torch
+import os
+import collections
+from collections import namedtuple
+from modules import shared, devices, script_callbacks
+from modules.paths import models_path
+import glob
+from copy import deepcopy
+
+
+model_dir = "Stable-diffusion"
+model_path = os.path.abspath(os.path.join(models_path, model_dir))
+vae_dir = "VAE"
+vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
+
+
+vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
+
+default_vae_dict = {"auto": "auto", "None": None, None: None}
+default_vae_list = ["auto", "None"]
+
+
+default_vae_values = [default_vae_dict[x] for x in default_vae_list]
+vae_dict = dict(default_vae_dict)
+vae_list = list(default_vae_list)
+first_load = True
+
+
+base_vae = None
+loaded_vae_file = None
+checkpoint_info = None
+
+checkpoints_loaded = collections.OrderedDict()
+
+def get_base_vae(model):
+    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
+        return base_vae
+    return None
+
+
+def store_base_vae(model):
+    global base_vae, checkpoint_info
+    if checkpoint_info != model.sd_checkpoint_info:
+        assert not loaded_vae_file, "Trying to store non-base VAE!"
+        base_vae = deepcopy(model.first_stage_model.state_dict())
+        checkpoint_info = model.sd_checkpoint_info
+
+
+def delete_base_vae():
+    global base_vae, checkpoint_info
+    base_vae = None
+    checkpoint_info = None
+
+
+def restore_base_vae(model):
+    global loaded_vae_file
+    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
+        print("Restoring base VAE")
+        _load_vae_dict(model, base_vae)
+        loaded_vae_file = None
+    delete_base_vae()
+
+
+def get_filename(filepath):
+    return os.path.splitext(os.path.basename(filepath))[0]
+
+
+def refresh_vae_list(vae_path=vae_path, model_path=model_path):
+    global vae_dict, vae_list
+    res = {}
+    candidates = [
+        *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
+        *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
+        *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
+        *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
+    ]
+    if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
+        candidates.append(shared.cmd_opts.vae_path)
+    for filepath in candidates:
+        name = get_filename(filepath)
+        res[name] = filepath
+    vae_list.clear()
+    vae_list.extend(default_vae_list)
+    vae_list.extend(list(res.keys()))
+    vae_dict.clear()
+    vae_dict.update(res)
+    vae_dict.update(default_vae_dict)
+    return vae_list
+
+
+def get_vae_from_settings(vae_file="auto"):
+    # else, we load from settings, if not set to be default
+    if vae_file == "auto" and shared.opts.sd_vae is not None:
+        # if saved VAE settings isn't recognized, fallback to auto
+        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
+        # if VAE selected but not found, fallback to auto
+        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
+            vae_file = "auto"
+            print(f"Selected VAE doesn't exist: {vae_file}")
+    return vae_file
+
+
+def resolve_vae(checkpoint_file=None, vae_file="auto"):
+    global first_load, vae_dict, vae_list
+
+    # if vae_file argument is provided, it takes priority, but not saved
+    if vae_file and vae_file not in default_vae_list:
+        if not os.path.isfile(vae_file):
+            print(f"VAE provided as function argument doesn't exist: {vae_file}")
+            vae_file = "auto"
+    # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
+    if first_load and shared.cmd_opts.vae_path is not None:
+        if os.path.isfile(shared.cmd_opts.vae_path):
+            vae_file = shared.cmd_opts.vae_path
+            shared.opts.data['sd_vae'] = get_filename(vae_file)
+        else:
+            print(f"VAE provided as command line argument doesn't exist: {vae_file}")
+    # fallback to selector in settings, if vae selector not set to act as default fallback
+    if not shared.opts.sd_vae_as_default:
+        vae_file = get_vae_from_settings(vae_file)
+    # vae-path cmd arg takes priority for auto
+    if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
+        if os.path.isfile(shared.cmd_opts.vae_path):
+            vae_file = shared.cmd_opts.vae_path
+            print(f"Using VAE provided as command line argument: {vae_file}")
+    # if still not found, try look for ".vae.pt" beside model
+    model_path = os.path.splitext(checkpoint_file)[0]
+    if vae_file == "auto":
+        vae_file_try = model_path + ".vae.pt"
+        if os.path.isfile(vae_file_try):
+            vae_file = vae_file_try
+            print(f"Using VAE found similar to selected model: {vae_file}")
+    # if still not found, try look for ".vae.ckpt" beside model
+    if vae_file == "auto":
+        vae_file_try = model_path + ".vae.ckpt"
+        if os.path.isfile(vae_file_try):
+            vae_file = vae_file_try
+            print(f"Using VAE found similar to selected model: {vae_file}")
+    # No more fallbacks for auto
+    if vae_file == "auto":
+        vae_file = None
+    # Last check, just because
+    if vae_file and not os.path.exists(vae_file):
+        vae_file = None
+
+    return vae_file
+
+
+def load_vae(model, vae_file=None):
+    global first_load, vae_dict, vae_list, loaded_vae_file
+    # save_settings = False
+
+    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
+
+    if vae_file:
+        if cache_enabled and vae_file in checkpoints_loaded:
+            # use vae checkpoint cache
+            print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+            store_base_vae(model)
+            _load_vae_dict(model, checkpoints_loaded[vae_file])
+        else:
+            assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
+            print(f"Loading VAE weights from: {vae_file}")
+            store_base_vae(model)
+            vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+            vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+            _load_vae_dict(model, vae_dict_1)
+
+            if cache_enabled:
+                # cache newly loaded vae
+                checkpoints_loaded[vae_file] = vae_dict_1.copy()
+
+        # clean up cache if limit is reached
+        if cache_enabled:
+            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
+                checkpoints_loaded.popitem(last=False)  # LRU
+
+        # If vae used is not in dict, update it
+        # It will be removed on refresh though
+        vae_opt = get_filename(vae_file)
+        if vae_opt not in vae_dict:
+            vae_dict[vae_opt] = vae_file
+            vae_list.append(vae_opt)
+    elif loaded_vae_file:
+        restore_base_vae(model)
+
+    loaded_vae_file = vae_file
+
+    first_load = False
+
+
+# don't call this from outside
+def _load_vae_dict(model, vae_dict_1):
+    model.first_stage_model.load_state_dict(vae_dict_1)
+    model.first_stage_model.to(devices.dtype_vae)
+
+def clear_loaded_vae():
+    global loaded_vae_file
+    loaded_vae_file = None
+
+def reload_vae_weights(sd_model=None, vae_file="auto"):
+    from modules import lowvram, devices, sd_hijack
+
+    if not sd_model:
+        sd_model = shared.sd_model
+
+    checkpoint_info = sd_model.sd_checkpoint_info
+    checkpoint_file = checkpoint_info.filename
+    vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
+
+    if loaded_vae_file == vae_file:
+        return
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.send_everything_to_cpu()
+    else:
+        sd_model.to(devices.cpu)
+
+    sd_hijack.model_hijack.undo_hijack(sd_model)
+
+    load_vae(sd_model, vae_file)
+
+    sd_hijack.model_hijack.hijack(sd_model)
+    script_callbacks.model_loaded_callback(sd_model)
+
+    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+        sd_model.to(devices.device)
+
+    print("VAE Weights loaded.")
+    return sd_model
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
new file mode 100644
index 00000000..0a58542d
--- /dev/null
+++ b/modules/sd_vae_approx.py
@@ -0,0 +1,58 @@
+import os
+
+import torch
+from torch import nn
+from modules import devices, paths
+
+sd_vae_approx_model = None
+
+
+class VAEApprox(nn.Module):
+    def __init__(self):
+        super(VAEApprox, self).__init__()
+        self.conv1 = nn.Conv2d(4, 8, (7, 7))
+        self.conv2 = nn.Conv2d(8, 16, (5, 5))
+        self.conv3 = nn.Conv2d(16, 32, (3, 3))
+        self.conv4 = nn.Conv2d(32, 64, (3, 3))
+        self.conv5 = nn.Conv2d(64, 32, (3, 3))
+        self.conv6 = nn.Conv2d(32, 16, (3, 3))
+        self.conv7 = nn.Conv2d(16, 8, (3, 3))
+        self.conv8 = nn.Conv2d(8, 3, (3, 3))
+
+    def forward(self, x):
+        extra = 11
+        x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+        x = nn.functional.pad(x, (extra, extra, extra, extra))
+
+        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
+            x = layer(x)
+            x = nn.functional.leaky_relu(x, 0.1)
+
+        return x
+
+
+def model():
+    global sd_vae_approx_model
+
+    if sd_vae_approx_model is None:
+        sd_vae_approx_model = VAEApprox()
+        sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
+        sd_vae_approx_model.eval()
+        sd_vae_approx_model.to(devices.device, devices.dtype)
+
+    return sd_vae_approx_model
+
+
+def cheap_approximation(sample):
+    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
+
+    coefs = torch.tensor([
+        [0.298, 0.207, 0.208],
+        [0.187, 0.286, 0.173],
+        [-0.158, 0.189, 0.264],
+        [-0.184, -0.271, -0.473],
+    ]).to(sample.device)
+
+    x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+
+    return x_sample
diff --git a/modules/shared.py b/modules/shared.py
index faede821..54a6ba23 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -3,24 +3,27 @@ import datetime
 import json
 import os
 import sys
+import time
 
+from PIL import Image
 import gradio as gr
 import tqdm
 
 import modules.artists
 import modules.interrogate
 import modules.memmon
-import modules.sd_models
 import modules.styles
 import modules.devices as devices
-from modules import sd_samplers, sd_models, localization
-from modules.hypernetworks import hypernetwork
+from modules import localization, sd_vae, extensions, script_loading, errors
 from modules.paths import models_path, script_path, sd_path
 
+
+demo = None
+
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
 parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
+parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
 parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
 parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
 parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
@@ -39,34 +42,35 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
 parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
 parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
 parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
-parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
 parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
 parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
+parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
 parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
 parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
 parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
 parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
 parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
-parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
-parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
-parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
+parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
 parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
-parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
+parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
 parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
 parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
 parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
 parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
 parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
 parser.add_argument("--gradio-debug",  action='store_true', help="launch gradio with --debug option")
 parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
 parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
+parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
 parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
 parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
 parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
@@ -76,12 +80,27 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
 parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
 parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
 parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
-parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
-parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
+parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
+parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
+parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
+parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
+parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
+parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
+parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
+parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
+parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
+
+script_loading.preload_extensions(extensions.extensions_dir, parser)
+script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
 
 cmd_opts = parser.parse_args()
-restricted_opts = [
+
+restricted_opts = {
     "samples_filename_pattern",
+    "directories_filename_pattern",
     "outdir_samples",
     "outdir_txt2img_samples",
     "outdir_img2img_samples",
@@ -89,10 +108,23 @@ restricted_opts = [
     "outdir_grids",
     "outdir_txt2img_grids",
     "outdir_save",
+}
+
+ui_reorder_categories = [
+    "sampler",
+    "dimensions",
+    "cfg",
+    "seed",
+    "checkboxes",
+    "hires_fix",
+    "batch",
+    "scripts",
 ]
 
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
+cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
+
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
+    (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
 
 device = devices.device
 weight_load_location = None if cmd_opts.lowram else "cpu"
@@ -103,10 +135,12 @@ xformers_available = False
 config_filename = cmd_opts.ui_settings_file
 
 os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
-hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+hypernetworks = {}
 loaded_hypernetwork = None
 
+
 def reload_hypernetworks():
+    from modules.hypernetworks import hypernetwork
     global hypernetworks
 
     hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
@@ -126,6 +160,8 @@ class State:
     current_image = None
     current_image_sampling_step = 0
     textinfo = None
+    time_start = None
+    need_restart = False
 
     def skip(self):
         self.skipped = True
@@ -134,12 +170,67 @@ class State:
         self.interrupted = True
 
     def nextjob(self):
+        if opts.show_progress_every_n_steps == -1:
+            self.do_set_current_image()
+
         self.job_no += 1
         self.sampling_step = 0
         self.current_image_sampling_step = 0
 
-    def get_job_timestamp(self):
-        return datetime.datetime.now().strftime("%Y%m%d%H%M%S")  # shouldn't this return job_timestamp?
+    def dict(self):
+        obj = {
+            "skipped": self.skipped,
+            "interrupted": self.interrupted,
+            "job": self.job,
+            "job_count": self.job_count,
+            "job_timestamp": self.job_timestamp,
+            "job_no": self.job_no,
+            "sampling_step": self.sampling_step,
+            "sampling_steps": self.sampling_steps,
+        }
+
+        return obj
+
+    def begin(self):
+        self.sampling_step = 0
+        self.job_count = -1
+        self.job_no = 0
+        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+        self.current_latent = None
+        self.current_image = None
+        self.current_image_sampling_step = 0
+        self.skipped = False
+        self.interrupted = False
+        self.textinfo = None
+        self.time_start = time.time()
+
+        devices.torch_gc()
+
+    def end(self):
+        self.job = ""
+        self.job_count = 0
+
+        devices.torch_gc()
+
+    """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
+    def set_current_image(self):
+        if not parallel_processing_allowed:
+            return
+
+        if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
+            self.do_set_current_image()
+
+    def do_set_current_image(self):
+        if self.current_latent is None:
+            return
+
+        import modules.sd_samplers
+        if opts.show_progress_grid:
+            self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
+        else:
+            self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
+
+        self.current_image_sampling_step = self.sampling_step
 
 
 state = State()
@@ -153,8 +244,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
 
 face_restorers = []
 
-localization.list_localizations(cmd_opts.localizations_dir)
-
 
 def realesrgan_models_names():
     import modules.realesrgan_model
@@ -162,13 +251,13 @@ def realesrgan_models_names():
 
 
 class OptionInfo:
-    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
+    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
         self.default = default
         self.label = label
         self.component = component
         self.component_args = component_args
         self.onchange = onchange
-        self.section = None
+        self.section = section
         self.refresh = refresh
 
 
@@ -179,6 +268,21 @@ def options_section(section_identifier, options_dict):
     return options_dict
 
 
+def list_checkpoint_tiles():
+    import modules.sd_models
+    return modules.sd_models.checkpoint_tiles()
+
+
+def refresh_checkpoints():
+    import modules.sd_models
+    return modules.sd_models.list_models()
+
+
+def list_samplers():
+    import modules.sd_samplers
+    return modules.sd_samplers.all_samplers
+
+
 hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
 
 options_templates = {}
@@ -186,7 +290,8 @@ options_templates = {}
 options_templates.update(options_section(('saving-images', "Saving images/grids"), {
     "samples_save": OptionInfo(True, "Always save all generated images"),
     "samples_format": OptionInfo('png', 'File format for images'),
-    "samples_filename_pattern": OptionInfo("", "Images filename pattern"),
+    "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
+    "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
 
     "grid_save": OptionInfo(True, "Always save all generated image grids"),
     "grid_format": OptionInfo('png', 'File format for grids'),
@@ -198,12 +303,19 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
     "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
     "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
     "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
+    "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
+    "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
     "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
     "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
 
     "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
+    "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
     "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
     "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
+
+    "temp_dir":  OptionInfo("", "Directory for temporary images; leave empty for default"),
+    "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
+
 }))
 
 options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -221,19 +333,15 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
     "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
     "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
     "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
-    "directories_filename_pattern": OptionInfo("", "Directory name pattern"),
-    "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
+    "directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
+    "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
 }))
 
 options_templates.update(options_section(('upscaling', "Upscaling"), {
     "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
     "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
-    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
-    "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
-    "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
-    "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
+    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
     "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
-    "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
 }))
 
 options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -249,31 +357,42 @@ options_templates.update(options_section(('system', "System"), {
 }))
 
 options_templates.update(options_section(('training', "Training"), {
-    "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
+    "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
+    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
     "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
     "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
+    "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
 }))
 
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
-    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
     "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
+    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
     "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
-    "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
     "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
+    "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
     "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
     "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
-    "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
-    "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
-    'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+    'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
     "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
 }))
 
+options_templates.update(options_section(('compatibility', "Compatibility"), {
+    "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
+    "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+}))
+
 options_templates.update(options_section(('interrogate', "Interrogate Options"), {
     "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
     "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
@@ -286,26 +405,34 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
     "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
     "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
     "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
+    "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
 }))
 
 options_templates.update(options_section(('ui', "User interface"), {
     "show_progressbar": OptionInfo(True, "Show progressbar"),
-    "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+    "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
+    "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+    "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
     "return_grid": OptionInfo(True, "Show grid in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
     "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
     "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
+    "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
+    "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
     "font": OptionInfo("", "Font for image grids that have text"),
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
     "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
+    "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
+    "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
     'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
+    'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"),
     'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
 }))
 
 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
-    "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
+    "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
     "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
     "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
     "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
@@ -315,6 +442,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
     'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
 }))
 
+options_templates.update(options_section((None, "Hidden options"), {
+    "disabled_extensions": OptionInfo([], "Disable those extensions"),
+}))
+
+options_templates.update()
+
 
 class Options:
     data = None
@@ -326,8 +459,19 @@ class Options:
 
     def __setattr__(self, key, value):
         if self.data is not None:
-            if key in self.data:
+            if key in self.data or key in self.data_labels:
+                assert not cmd_opts.freeze_settings, "changing settings is disabled"
+
+                info = opts.data_labels.get(key, None)
+                comp_args = info.component_args if info else None
+                if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
+                    raise RuntimeError(f"not possible to set {key} because it is restricted")
+
+                if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+                    raise RuntimeError(f"not possible to set {key} because it is restricted")
+
                 self.data[key] = value
+                return
 
         return super(Options, self).__setattr__(key, value)
 
@@ -341,9 +485,33 @@ class Options:
 
         return super(Options, self).__getattribute__(item)
 
+    def set(self, key, value):
+        """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
+
+        oldval = self.data.get(key, None)
+        if oldval == value:
+            return False
+
+        try:
+            setattr(self, key, value)
+        except RuntimeError:
+            return False
+
+        if self.data_labels[key].onchange is not None:
+            try:
+                self.data_labels[key].onchange()
+            except Exception as e:
+                errors.display(e, f"changing setting {key} to {value}")
+                setattr(self, key, oldval)
+                return False
+
+        return True
+
     def save(self, filename):
+        assert not cmd_opts.freeze_settings, "saving settings is disabled"
+
         with open(filename, "w", encoding="utf8") as file:
-            json.dump(self.data, file)
+            json.dump(self.data, file, indent=4)
 
     def same_type(self, x, y):
         if x is None or y is None:
@@ -368,25 +536,51 @@ class Options:
         if bad_settings > 0:
             print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
 
-    def onchange(self, key, func):
+    def onchange(self, key, func, call=True):
         item = self.data_labels.get(key)
         item.onchange = func
 
-        func()
+        if call:
+            func()
 
     def dumpjson(self):
         d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
         return json.dumps(d)
 
+    def add_option(self, key, info):
+        self.data_labels[key] = info
+
+    def reorder(self):
+        """reorder settings so that all items related to section always go together"""
+
+        section_ids = {}
+        settings_items = self.data_labels.items()
+        for k, item in settings_items:
+            if item.section not in section_ids:
+                section_ids[item.section] = len(section_ids)
+
+        self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+
 
 opts = Options()
 if os.path.exists(config_filename):
     opts.load(config_filename)
 
+latent_upscale_default_mode = "Latent"
+latent_upscale_modes = {
+    "Latent": {"mode": "bilinear", "antialias": False},
+    "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
+    "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
+    "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
+    "Latent (nearest)": {"mode": "nearest", "antialias": False},
+}
+
 sd_upscalers = []
 
 sd_model = None
 
+clip_model = None
+
 progress_print_out = sys.stdout
 
 
@@ -426,3 +620,8 @@ total_tqdm = TotalTQDM()
 
 mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
 mem_mon.start()
+
+
+def listfiles(dirname):
+    filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
+    return [file for file in filenames if os.path.isfile(file)]
diff --git a/modules/styles.py b/modules/styles.py
index 3bf5c5b6..ce6e71ca 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -65,17 +65,6 @@ class StyleDatabase:
     def apply_negative_styles_to_prompt(self, prompt, styles):
         return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
 
-    def apply_styles(self, p: StableDiffusionProcessing) -> None:
-        if isinstance(p.prompt, list):
-            p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
-        else:
-            p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
-
-        if isinstance(p.negative_prompt, list):
-            p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
-        else:
-            p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
-
     def save_styles(self, path: str) -> None:
         # Write to temporary file first, so we don't nuke the file if something goes wrong
         fd, temp_path = tempfile.mkstemp(".csv")
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
new file mode 100644
index 00000000..68e1103c
--- /dev/null
+++ b/modules/textual_inversion/autocrop.py
@@ -0,0 +1,341 @@
+import cv2
+import requests
+import os
+from collections import defaultdict
+from math import log, sqrt
+import numpy as np
+from PIL import Image, ImageDraw
+
+GREEN = "#0F0"
+BLUE = "#00F"
+RED = "#F00"
+
+
+def crop_image(im, settings):
+  """ Intelligently crop an image to the subject matter """
+
+  scale_by = 1
+  if is_landscape(im.width, im.height):
+    scale_by = settings.crop_height / im.height
+  elif is_portrait(im.width, im.height):
+    scale_by = settings.crop_width / im.width
+  elif is_square(im.width, im.height):
+    if is_square(settings.crop_width, settings.crop_height):
+      scale_by = settings.crop_width / im.width
+    elif is_landscape(settings.crop_width, settings.crop_height):
+      scale_by = settings.crop_width / im.width
+    elif is_portrait(settings.crop_width, settings.crop_height):
+      scale_by = settings.crop_height / im.height
+
+  im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+  im_debug = im.copy()
+
+  focus = focal_point(im_debug, settings)
+
+  # take the focal point and turn it into crop coordinates that try to center over the focal
+  # point but then get adjusted back into the frame
+  y_half = int(settings.crop_height / 2)
+  x_half = int(settings.crop_width / 2)
+
+  x1 = focus.x - x_half
+  if x1 < 0:
+      x1 = 0
+  elif x1 + settings.crop_width > im.width:
+      x1 = im.width - settings.crop_width
+
+  y1 = focus.y - y_half
+  if y1 < 0:
+      y1 = 0
+  elif y1 + settings.crop_height > im.height:
+      y1 = im.height - settings.crop_height
+
+  x2 = x1 + settings.crop_width
+  y2 = y1 + settings.crop_height
+
+  crop = [x1, y1, x2, y2]
+
+  results = []
+
+  results.append(im.crop(tuple(crop)))
+
+  if settings.annotate_image:
+    d = ImageDraw.Draw(im_debug)
+    rect = list(crop)
+    rect[2] -= 1
+    rect[3] -= 1
+    d.rectangle(rect, outline=GREEN)
+    results.append(im_debug)
+    if settings.destop_view_image:
+      im_debug.show()
+
+  return results
+
+def focal_point(im, settings):
+    corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
+    entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
+    face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
+
+    pois = []
+
+    weight_pref_total = 0
+    if len(corner_points) > 0:
+      weight_pref_total += settings.corner_points_weight
+    if len(entropy_points) > 0:
+      weight_pref_total += settings.entropy_points_weight
+    if len(face_points) > 0:
+      weight_pref_total += settings.face_points_weight
+
+    corner_centroid = None
+    if len(corner_points) > 0:
+      corner_centroid = centroid(corner_points)
+      corner_centroid.weight = settings.corner_points_weight / weight_pref_total 
+      pois.append(corner_centroid)
+
+    entropy_centroid = None
+    if len(entropy_points) > 0:
+      entropy_centroid = centroid(entropy_points)
+      entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
+      pois.append(entropy_centroid)
+
+    face_centroid = None
+    if len(face_points) > 0:
+      face_centroid = centroid(face_points)
+      face_centroid.weight = settings.face_points_weight / weight_pref_total 
+      pois.append(face_centroid)
+
+    average_point = poi_average(pois, settings)
+
+    if settings.annotate_image:
+      d = ImageDraw.Draw(im)
+      max_size = min(im.width, im.height) * 0.07
+      if corner_centroid is not None:
+        color = BLUE
+        box = corner_centroid.bounding(max_size * corner_centroid.weight)
+        d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
+        d.ellipse(box, outline=color)
+        if len(corner_points) > 1:
+          for f in corner_points:
+            d.rectangle(f.bounding(4), outline=color)
+      if entropy_centroid is not None:
+        color = "#ff0"
+        box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
+        d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
+        d.ellipse(box, outline=color)
+        if len(entropy_points) > 1:
+          for f in entropy_points:
+            d.rectangle(f.bounding(4), outline=color)
+      if face_centroid is not None:
+        color = RED
+        box = face_centroid.bounding(max_size * face_centroid.weight)
+        d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
+        d.ellipse(box, outline=color)
+        if len(face_points) > 1:
+          for f in face_points:
+            d.rectangle(f.bounding(4), outline=color)
+
+      d.ellipse(average_point.bounding(max_size), outline=GREEN)
+      
+    return average_point
+
+
+def image_face_points(im, settings):
+    if settings.dnn_model_path is not None:
+      detector = cv2.FaceDetectorYN.create(
+          settings.dnn_model_path,
+          "",
+          (im.width, im.height),
+          0.9, # score threshold
+          0.3, # nms threshold
+          5000 # keep top k before nms
+      )
+      faces = detector.detect(np.array(im))
+      results = []
+      if faces[1] is not None:
+        for face in faces[1]:
+          x = face[0]
+          y = face[1]
+          w = face[2]
+          h = face[3]
+          results.append(
+            PointOfInterest(
+              int(x + (w * 0.5)), # face focus left/right is center
+              int(y + (h * 0.33)), # face focus up/down is close to the top of the head
+              size = w,
+              weight = 1/len(faces[1])
+            )
+          )
+      return results
+    else:
+      np_im = np.array(im)
+      gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+
+      tries = [
+        [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+        [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+        [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+        [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+        [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+        [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+        [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+        [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+      ]
+      for t in tries:
+        classifier = cv2.CascadeClassifier(t[0])
+        minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+        try:
+          faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+            minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+        except:
+          continue
+
+        if len(faces) > 0:
+          rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+          return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
+    return []
+
+
+def image_corner_points(im, settings):
+    grayscale = im.convert("L")
+
+    # naive attempt at preventing focal points from collecting at watermarks near the bottom
+    gd = ImageDraw.Draw(grayscale)
+    gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+    np_im = np.array(grayscale)
+
+    points = cv2.goodFeaturesToTrack(
+        np_im,
+        maxCorners=100,
+        qualityLevel=0.04,
+        minDistance=min(grayscale.width, grayscale.height)*0.06,
+        useHarrisDetector=False,
+    )
+
+    if points is None:
+        return []
+
+    focal_points = []
+    for point in points:
+      x, y = point.ravel()
+      focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
+
+    return focal_points
+
+
+def image_entropy_points(im, settings):
+    landscape = im.height < im.width
+    portrait = im.height > im.width
+    if landscape:
+      move_idx = [0, 2]
+      move_max = im.size[0]
+    elif portrait:
+      move_idx = [1, 3]
+      move_max = im.size[1]
+    else:
+      return []
+
+    e_max = 0
+    crop_current = [0, 0, settings.crop_width, settings.crop_height]
+    crop_best = crop_current
+    while crop_current[move_idx[1]] < move_max:
+        crop = im.crop(tuple(crop_current))
+        e = image_entropy(crop)
+
+        if (e > e_max):
+          e_max = e
+          crop_best = list(crop_current)
+
+        crop_current[move_idx[0]] += 4
+        crop_current[move_idx[1]] += 4
+
+    x_mid = int(crop_best[0] + settings.crop_width/2)
+    y_mid = int(crop_best[1] + settings.crop_height/2)
+
+    return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
+
+
+def image_entropy(im):
+    # greyscale image entropy
+    # band = np.asarray(im.convert("L"))
+    band = np.asarray(im.convert("1"), dtype=np.uint8)
+    hist, _ = np.histogram(band, bins=range(0, 256))
+    hist = hist[hist > 0]
+    return -np.log2(hist / hist.sum()).sum()
+
+def centroid(pois):
+  x = [poi.x for poi in pois]
+  y = [poi.y for poi in pois]
+  return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+
+
+def poi_average(pois, settings):
+    weight = 0.0
+    x = 0.0
+    y = 0.0
+    for poi in pois:
+        weight += poi.weight
+        x += poi.x * poi.weight
+        y += poi.y * poi.weight
+    avg_x = round(weight and x / weight)
+    avg_y = round(weight and y / weight)
+
+    return PointOfInterest(avg_x, avg_y)
+
+
+def is_landscape(w, h):
+  return w > h
+
+
+def is_portrait(w, h):
+  return h > w
+
+
+def is_square(w, h):
+  return w == h
+
+
+def download_and_cache_models(dirname):
+  download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
+  model_file_name = 'face_detection_yunet.onnx'
+
+  if not os.path.exists(dirname):
+    os.makedirs(dirname)
+
+  cache_file = os.path.join(dirname, model_file_name)
+  if not os.path.exists(cache_file):
+    print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
+    response = requests.get(download_url)
+    with open(cache_file, "wb") as f:
+      f.write(response.content)
+
+  if os.path.exists(cache_file):
+    return cache_file
+  return None
+
+
+class PointOfInterest:
+  def __init__(self, x, y, weight=1.0, size=10):
+    self.x = x
+    self.y = y
+    self.weight = weight
+    self.size = size
+
+  def bounding(self, size):
+    return [
+      self.x - size//2,
+      self.y - size//2,
+      self.x + size//2,
+      self.y + size//2
+    ]
+
+
+class Settings:
+  def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
+    self.crop_width = crop_width
+    self.crop_height = crop_height
+    self.corner_points_weight = corner_points_weight
+    self.entropy_points_weight = entropy_points_weight
+    self.face_points_weight = face_points_weight
+    self.annotate_image = annotate_image
+    self.destop_view_image = False
+    self.dnn_model_path = dnn_model_path
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 23bb4b6a..88d68c76 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -3,7 +3,7 @@ import numpy as np
 import PIL
 import torch
 from PIL import Image
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader
 from torchvision import transforms
 
 import random
@@ -11,25 +11,28 @@ import tqdm
 from modules import devices, shared
 import re
 
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
 re_numbers_at_start = re.compile(r"^[-\d]+\s*")
 
 
 class DatasetEntry:
-    def __init__(self, filename=None, latent=None, filename_text=None):
+    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
         self.filename = filename
-        self.latent = latent
         self.filename_text = filename_text
-        self.cond = None
-        self.cond_text = None
+        self.latent_dist = latent_dist
+        self.latent_sample = latent_sample
+        self.cond = cond
+        self.cond_text = cond_text
+        self.pixel_values = pixel_values
 
 
 class PersonalizedBase(Dataset):
-    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
         re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
 
         self.placeholder_token = placeholder_token
 
-        self.batch_size = batch_size
         self.width = width
         self.height = height
         self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -42,12 +45,19 @@ class PersonalizedBase(Dataset):
         self.lines = lines
 
         assert data_root, 'dataset directory not specified'
-
-        cond_model = shared.sd_model.cond_stage_model
+        assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+        assert os.listdir(data_root), "Dataset directory is empty"
 
         self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+
+
+        self.shuffle_tags = shuffle_tags
+        self.tag_drop_out = tag_drop_out
+
         print("Preparing dataset...")
         for path in tqdm.tqdm(self.image_paths):
+            if shared.state.interrupted:
+                raise Exception("interrupted")
             try:
                 image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
             except Exception:
@@ -69,53 +79,94 @@ class PersonalizedBase(Dataset):
             npimage = np.array(image).astype(np.uint8)
             npimage = (npimage / 127.5 - 1.0).astype(np.float32)
 
-            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
-            torchdata = torch.moveaxis(torchdata, 2, 0)
+            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
+            latent_sample = None
 
-            init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
-            init_latent = init_latent.to(devices.cpu)
+            with devices.autocast():
+                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
 
-            entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
+            if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
+                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+                latent_sampling_method = "once"
+                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+            elif latent_sampling_method == "deterministic":
+                # Works only for DiagonalGaussianDistribution
+                latent_dist.std = 0
+                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+            elif latent_sampling_method == "random":
+                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
 
-            if include_cond:
+            if not (self.tag_drop_out != 0 or self.shuffle_tags):
                 entry.cond_text = self.create_text(filename_text)
-                entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
+
+            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
+                with devices.autocast():
+                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
 
             self.dataset.append(entry)
+            del torchdata
+            del latent_dist
+            del latent_sample
 
-        assert len(self.dataset) > 1, "No images have been found in the dataset."
-        self.length = len(self.dataset) * repeats // batch_size
-
-        self.initial_indexes = np.arange(len(self.dataset))
-        self.indexes = None
-        self.shuffle()
-
-    def shuffle(self):
-        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+        self.length = len(self.dataset)
+        assert self.length > 0, "No images have been found in the dataset."
+        self.batch_size = min(batch_size, self.length)
+        self.gradient_step = min(gradient_step, self.length // self.batch_size)
+        self.latent_sampling_method = latent_sampling_method
 
     def create_text(self, filename_text):
         text = random.choice(self.lines)
+        tags = filename_text.split(',')
+        if self.tag_drop_out != 0:
+            tags = [t for t in tags if random.random() > self.tag_drop_out]
+        if self.shuffle_tags:
+            random.shuffle(tags)
+        text = text.replace("[filewords]", ','.join(tags))
         text = text.replace("[name]", self.placeholder_token)
-        text = text.replace("[filewords]", filename_text)
         return text
 
     def __len__(self):
         return self.length
 
     def __getitem__(self, i):
-        res = []
+        entry = self.dataset[i]
+        if self.tag_drop_out != 0 or self.shuffle_tags:
+            entry.cond_text = self.create_text(entry.filename_text)
+        if self.latent_sampling_method == "random":
+            entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
+        return entry
 
-        for j in range(self.batch_size):
-            position = i * self.batch_size + j
-            if position % len(self.indexes) == 0:
-                self.shuffle()
+class PersonalizedDataLoader(DataLoader):
+    def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
+        super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+        if latent_sampling_method == "random":
+            self.collate_fn = collate_wrapper_random
+        else:
+            self.collate_fn = collate_wrapper
 
-            index = self.indexes[position % len(self.indexes)]
-            entry = self.dataset[index]
 
-            if entry.cond is None:
-                entry.cond_text = self.create_text(entry.filename_text)
+class BatchLoader:
+    def __init__(self, data):
+        self.cond_text = [entry.cond_text for entry in data]
+        self.cond = [entry.cond for entry in data]
+        self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+        #self.emb_index = [entry.emb_index for entry in data]
+        #print(self.latent_sample.device)
 
-            res.append(entry)
+    def pin_memory(self):
+        self.latent_sample = self.latent_sample.pin_memory()
+        return self
 
-        return res
+def collate_wrapper(batch):
+    return BatchLoader(batch)
+
+class BatchLoaderRandom(BatchLoader):
+    def __init__(self, data):
+        super().__init__(data)
+
+    def pin_memory(self):
+        return self
+
+def collate_wrapper_random(batch):
+    return BatchLoaderRandom(batch)
\ No newline at end of file
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index 898ce3b3..ea653806 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -5,6 +5,7 @@ import zlib
 from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
 from fonts.ttf import Roboto
 import torch
+from modules.shared import opts
 
 
 class EmbeddingEncoder(json.JSONEncoder):
@@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
     from math import cos
 
     image = srcimage.copy()
-
+    fontsize = 32
     if textfont is None:
         try:
             textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
@@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
     image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
 
     draw = ImageDraw.Draw(image)
-    fontsize = 32
+
     font = ImageFont.truetype(textfont, fontsize)
     padding = 10
 
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 2062726a..dd0c0ad1 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -4,30 +4,37 @@ import tqdm
 class LearnScheduleIterator:
     def __init__(self, learn_rate, max_steps, cur_step=0):
         """
-        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
+        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
         """
 
         pairs = learn_rate.split(',')
         self.rates = []
         self.it = 0
         self.maxit = 0
-        for i, pair in enumerate(pairs):
-            tmp = pair.split(':')
-            if len(tmp) == 2:
-                step = int(tmp[1])
-                if step > cur_step:
-                    self.rates.append((float(tmp[0]), min(step, max_steps)))
-                    self.maxit += 1
-                    if step > max_steps:
+        try:
+            for i, pair in enumerate(pairs):
+                if not pair.strip():
+                    continue
+                tmp = pair.split(':')
+                if len(tmp) == 2:
+                    step = int(tmp[1])
+                    if step > cur_step:
+                        self.rates.append((float(tmp[0]), min(step, max_steps)))
+                        self.maxit += 1
+                        if step > max_steps:
+                            return
+                    elif step == -1:
+                        self.rates.append((float(tmp[0]), max_steps))
+                        self.maxit += 1
                         return
-                elif step == -1:
+                else:
                     self.rates.append((float(tmp[0]), max_steps))
                     self.maxit += 1
                     return
-            else:
-                self.rates.append((float(tmp[0]), max_steps))
-                self.maxit += 1
-                return
+            assert self.rates
+        except (ValueError, AssertionError):
+            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
+
 
     def __iter__(self):
         return self
@@ -52,7 +59,7 @@ class LearnRateScheduler:
         self.finished = False
 
     def apply(self, optimizer, step_number):
-        if step_number <= self.end_step:
+        if step_number < self.end_step:
             return
 
         try:
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 886cf0c3..feb876c6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,27 +1,26 @@
 import os
 from PIL import Image, ImageOps
+import math
 import platform
 import sys
 import tqdm
 import time
 
-from modules import shared, images
+from modules import shared, images, deepbooru
+from modules.paths import models_path
 from modules.shared import opts, cmd_opts
-if cmd_opts.deepdanbooru:
-    import modules.deepbooru as deepbooru
+from modules.textual_inversion import autocrop
 
 
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
     try:
         if process_caption:
             shared.interrogator.load()
 
         if process_caption_deepbooru:
-            db_opts = deepbooru.create_deepbooru_opts()
-            db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
-            deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
+            deepbooru.model.start()
 
-        preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+        preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
 
     finally:
 
@@ -29,88 +28,169 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
             shared.interrogator.send_blip_to_ram()
 
         if process_caption_deepbooru:
-            deepbooru.release_process()
+            deepbooru.model.stop()
 
 
+def listfiles(dirname):
+    return os.listdir(dirname)
 
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+
+class PreprocessParams:
+    src = None
+    dstdir = None
+    subindex = 0
+    flip = False
+    process_caption = False
+    process_caption_deepbooru = False
+    preprocess_txt_action = None
+
+
+def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
+    caption = ""
+
+    if params.process_caption:
+        caption += shared.interrogator.generate_caption(image)
+
+    if params.process_caption_deepbooru:
+        if len(caption) > 0:
+            caption += ", "
+        caption += deepbooru.model.tag_multi(image)
+
+    filename_part = params.src
+    filename_part = os.path.splitext(filename_part)[0]
+    filename_part = os.path.basename(filename_part)
+
+    basename = f"{index:05}-{params.subindex}-{filename_part}"
+    image.save(os.path.join(params.dstdir, f"{basename}.png"))
+
+    if params.preprocess_txt_action == 'prepend' and existing_caption:
+        caption = existing_caption + ' ' + caption
+    elif params.preprocess_txt_action == 'append' and existing_caption:
+        caption = caption + ' ' + existing_caption
+    elif params.preprocess_txt_action == 'copy' and existing_caption:
+        caption = existing_caption
+
+    caption = caption.strip()
+
+    if len(caption) > 0:
+        with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
+            file.write(caption)
+
+    params.subindex += 1
+
+
+def save_pic(image, index, params, existing_caption=None):
+    save_pic_with_caption(image, index, params, existing_caption=existing_caption)
+
+    if params.flip:
+        save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
+
+
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
+    if inverse_xy:
+        from_w, from_h = image.height, image.width
+        to_w, to_h = height, width
+    else:
+        from_w, from_h = image.width, image.height
+        to_w, to_h = width, height
+    h = from_h * to_w // from_w
+    if inverse_xy:
+        image = image.resize((h, to_w))
+    else:
+        image = image.resize((to_w, h))
+
+    split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+    y_step = (h - to_h) / (split_count - 1)
+    for i in range(split_count):
+        y = int(y_step * i)
+        if inverse_xy:
+            splitted = image.crop((y, 0, y + to_h, to_w))
+        else:
+            splitted = image.crop((0, y, to_w, y + to_h))
+        yield splitted
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
     width = process_width
     height = process_height
     src = os.path.abspath(process_src)
     dst = os.path.abspath(process_dst)
+    split_threshold = max(0.0, min(1.0, split_threshold))
+    overlap_ratio = max(0.0, min(0.9, overlap_ratio))
 
     assert src != dst, 'same directory specified as source and destination'
 
     os.makedirs(dst, exist_ok=True)
 
-    files = os.listdir(src)
+    files = listfiles(src)
 
+    shared.state.job = "preprocess"
     shared.state.textinfo = "Preprocessing..."
     shared.state.job_count = len(files)
 
-    def save_pic_with_caption(image, index):
-        caption = ""
-
-        if process_caption:
-            caption += shared.interrogator.generate_caption(image)
-
-        if process_caption_deepbooru:
-            if len(caption) > 0:
-                caption += ", "
-            caption += deepbooru.get_tags_from_process(image)
-
-        filename_part = filename
-        filename_part = os.path.splitext(filename_part)[0]
-        filename_part = os.path.basename(filename_part)
-
-        basename = f"{index:05}-{subindex[0]}-{filename_part}"
-        image.save(os.path.join(dst, f"{basename}.png"))
-
-        if len(caption) > 0:
-            with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
-                file.write(caption)
-
-        subindex[0] += 1
-
-    def save_pic(image, index):
-        save_pic_with_caption(image, index)
-
-        if process_flip:
-            save_pic_with_caption(ImageOps.mirror(image), index)
+    params = PreprocessParams()
+    params.dstdir = dst
+    params.flip = process_flip
+    params.process_caption = process_caption
+    params.process_caption_deepbooru = process_caption_deepbooru
+    params.preprocess_txt_action = preprocess_txt_action
 
     for index, imagefile in enumerate(tqdm.tqdm(files)):
-        subindex = [0]
+        params.subindex = 0
         filename = os.path.join(src, imagefile)
         try:
             img = Image.open(filename).convert("RGB")
         except Exception:
             continue
 
+        params.src = filename
+
+        existing_caption = None
+        existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+        if os.path.exists(existing_caption_filename):
+            with open(existing_caption_filename, 'r', encoding="utf8") as file:
+                existing_caption = file.read()
+
         if shared.state.interrupted:
             break
 
-        ratio = img.height / img.width
-        is_tall = ratio > 1.35
-        is_wide = ratio < 1 / 1.35
-
-        if process_split and is_tall:
-            img = img.resize((width, height * img.height // img.width))
-
-            top = img.crop((0, 0, width, height))
-            save_pic(top, index)
-
-            bot = img.crop((0, img.height - height, width, img.height))
-            save_pic(bot, index)
-        elif process_split and is_wide:
-            img = img.resize((width * img.width // img.height, height))
-
-            left = img.crop((0, 0, width, height))
-            save_pic(left, index)
-
-            right = img.crop((img.width - width, 0, img.width, height))
-            save_pic(right, index)
+        if img.height > img.width:
+            ratio = (img.width * height) / (img.height * width)
+            inverse_xy = False
         else:
+            ratio = (img.height * width) / (img.width * height)
+            inverse_xy = True
+
+        process_default_resize = True
+
+        if process_split and ratio < 1.0 and ratio <= split_threshold:
+            for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
+                save_pic(splitted, index, params, existing_caption=existing_caption)
+            process_default_resize = False
+
+        if process_focal_crop and img.height != img.width:
+
+            dnn_model_path = None
+            try:
+                dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
+            except Exception as e:
+                print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
+
+            autocrop_settings = autocrop.Settings(
+                crop_width = width,
+                crop_height = height,
+                face_points_weight = process_focal_crop_face_weight,
+                entropy_points_weight = process_focal_crop_entropy_weight,
+                corner_points_weight = process_focal_crop_edges_weight,
+                annotate_image = process_focal_crop_debug,
+                dnn_model_path = dnn_model_path,
+            )
+            for focal in autocrop.crop_image(img, autocrop_settings):
+                save_pic(focal, index, params, existing_caption=existing_caption)
+            process_default_resize = False
+
+        if process_default_resize:
             img = images.resize_image(1, img, width, height)
-            save_pic(img, index)
+            save_pic(img, index, params, existing_caption=existing_caption)
 
         shared.state.nextjob()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 3be69562..2250e41b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -10,7 +10,7 @@ import csv
 
 from PIL import Image, PngImagePlugin
 
-from modules import shared, devices, sd_hijack, processing, sd_models
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 
@@ -23,9 +23,12 @@ class Embedding:
         self.vec = vec
         self.name = name
         self.step = step
+        self.shape = None
+        self.vectors = 0
         self.cached_checksum = None
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None
+        self.optimizer_state_dict = None
 
     def save(self, filename):
         embedding_data = {
@@ -39,6 +42,13 @@ class Embedding:
 
         torch.save(embedding_data, filename)
 
+        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+            optimizer_saved_dict = {
+                'hash': self.checksum(),
+                'optimizer_state_dict': self.optimizer_state_dict,
+            }
+            torch.save(optimizer_saved_dict, filename + '.optim')
+
     def checksum(self):
         if self.cached_checksum is not None:
             return self.cached_checksum
@@ -57,14 +67,17 @@ class EmbeddingDatabase:
     def __init__(self, embeddings_dir):
         self.ids_lookup = {}
         self.word_embeddings = {}
+        self.skipped_embeddings = {}
         self.dir_mtime = None
         self.embeddings_dir = embeddings_dir
+        self.expected_shape = -1
 
     def register_embedding(self, embedding, model):
 
         self.word_embeddings[embedding.name] = embedding
 
-        ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+        # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
+        ids = model.cond_stage_model.tokenize([embedding.name])[0]
 
         first_id = ids[0]
         if first_id not in self.ids_lookup:
@@ -74,21 +87,26 @@ class EmbeddingDatabase:
 
         return embedding
 
-    def load_textual_inversion_embeddings(self):
+    def get_expected_shape(self):
+        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
+        return vec.shape[1]
+
+    def load_textual_inversion_embeddings(self, force_reload = False):
         mt = os.path.getmtime(self.embeddings_dir)
-        if self.dir_mtime is not None and mt <= self.dir_mtime:
+        if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
             return
 
         self.dir_mtime = mt
         self.ids_lookup.clear()
         self.word_embeddings.clear()
+        self.skipped_embeddings.clear()
+        self.expected_shape = self.get_expected_shape()
 
         def process_file(path, filename):
-            name = os.path.splitext(filename)[0]
+            name, ext = os.path.splitext(filename)
+            ext = ext.upper()
 
-            data = []
-
-            if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+            if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                 embed_image = Image.open(path)
                 if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                     data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -96,8 +114,10 @@ class EmbeddingDatabase:
                 else:
                     data = extract_image_data_embed(embed_image)
                     name = data.get('name', name)
-            else:
+            elif ext in ['.BIN', '.PT']:
                 data = torch.load(path, map_location="cpu")
+            else:
+                return
 
             # textual inversion embeddings
             if 'string_to_param' in data:
@@ -119,9 +139,15 @@ class EmbeddingDatabase:
             vec = emb.detach().to(devices.device, dtype=torch.float32)
             embedding = Embedding(vec, name)
             embedding.step = data.get('step', None)
-            embedding.sd_checkpoint = data.get('hash', None)
+            embedding.sd_checkpoint = data.get('sd_checkpoint', None)
             embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-            self.register_embedding(embedding, shared.sd_model)
+            embedding.vectors = vec.shape[0]
+            embedding.shape = vec.shape[-1]
+
+            if self.expected_shape == -1 or self.expected_shape == embedding.shape:
+                self.register_embedding(embedding, shared.sd_model)
+            else:
+                self.skipped_embeddings[name] = embedding
 
         for fn in os.listdir(self.embeddings_dir):
             try:
@@ -132,12 +158,13 @@ class EmbeddingDatabase:
 
                 process_file(fullfn, fn)
             except Exception:
-                print(f"Error loading emedding {fn}:", file=sys.stderr)
+                print(f"Error loading embedding {fn}:", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
                 continue
 
-        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
-        print("Embeddings:", ', '.join(self.word_embeddings.keys()))
+        print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+        if len(self.skipped_embeddings) > 0:
+            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
 
     def find_embedding_at_position(self, tokens, offset):
         token = tokens[offset]
@@ -153,19 +180,23 @@ class EmbeddingDatabase:
         return None, None
 
 
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
-    embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
 
-    ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
-    embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+    with devices.autocast():
+        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active
+
+    embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
 
     for i in range(num_vectors_per_token):
         vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
 
+    # Remove illegal characters from name.
+    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
     fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
-    assert not os.path.exists(fn), f"file {fn} already exists"
+    if not overwrite_old:
+        assert not os.path.exists(fn), f"file {fn} already exists"
 
     embedding = Embedding(vec, name)
     embedding.step = 0
@@ -180,7 +211,6 @@ def write_loss(log_directory, filename, step, epoch_len, values):
 
     if step % shared.opts.training_write_csv_every != 0:
         return
-
     write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
 
     with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
@@ -189,26 +219,52 @@ def write_loss(log_directory, filename, step, epoch_len, values):
         if write_csv_header:
             csv_writer.writeheader()
 
-        epoch = step // epoch_len
-        epoch_step = step - epoch * epoch_len
+        epoch = (step - 1) // epoch_len
+        epoch_step = (step - 1) % epoch_len
 
         csv_writer.writerow({
-            "step": step + 1,
-            "epoch": epoch + 1,
-            "epoch_step": epoch_step + 1,
+            "step": step,
+            "epoch": epoch,
+            "epoch_step": epoch_step,
             **values,
         })
 
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+    assert model_name, f"{name} not selected"
+    assert learn_rate, "Learning rate is empty or 0"
+    assert isinstance(batch_size, int), "Batch size must be integer"
+    assert batch_size > 0, "Batch size must be positive"
+    assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
+    assert gradient_step > 0, "Gradient accumulation step must be positive"
+    assert data_root, "Dataset directory is empty"
+    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+    assert os.listdir(data_root), "Dataset directory is empty"
+    assert template_file, "Prompt template file is empty"
+    assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+    assert steps, "Max steps is empty or 0"
+    assert isinstance(steps, int), "Max steps must be integer"
+    assert steps > 0 , "Max steps must be positive"
+    assert isinstance(save_model_every, int), "Save {name} must be integer"
+    assert save_model_every >= 0 , "Save {name} must be positive or 0"
+    assert isinstance(create_image_every, int), "Create image must be integer"
+    assert create_image_every >= 0 , "Create image must be positive or 0"
+    if save_model_every or create_image_every:
+        assert log_directory, "Log directory is empty"
 
-def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
-    assert embedding_name, 'embedding not selected'
 
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+    save_embedding_every = save_embedding_every or 0
+    create_image_every = create_image_every or 0
+    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+
+    shared.state.job = "train-embedding"
     shared.state.textinfo = "Initializing textual inversion training..."
     shared.state.job_count = steps
 
     filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
 
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+    unload = shared.opts.unload_models_when_training
 
     if save_embedding_every > 0:
         embedding_dir = os.path.join(log_directory, "embeddings")
@@ -227,148 +283,241 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         os.makedirs(images_embeds_dir, exist_ok=True)
     else:
         images_embeds_dir = None
-        
-    cond_model = shared.sd_model.cond_stage_model
-
-    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
-    with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
 
     hijack = sd_hijack.model_hijack
 
     embedding = hijack.embedding_db.word_embeddings[embedding_name]
-    embedding.vec.requires_grad = True
+    checkpoint = sd_models.select_checkpoint()
 
-    losses = torch.zeros((32,))
+    initial_step = embedding.step or 0
+    if initial_step >= steps:
+        shared.state.textinfo = "Model has already been trained beyond specified max steps"
+        return embedding, filename
+    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
+
+    # dataset loading may take a while, so input validations and early returns should be done before this
+    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+    old_parallel_processing_allowed = shared.parallel_processing_allowed
+
+    pin_memory = shared.opts.pin_memory
+
+    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+
+    latent_sampling_method = ds.latent_sampling_method
+
+    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
+
+    if unload:
+        shared.parallel_processing_allowed = False
+        shared.sd_model.first_stage_model.to(devices.cpu)
+
+    embedding.vec.requires_grad = True
+    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+    if shared.opts.save_optimizer_state:
+        optimizer_state_dict = None
+        if os.path.exists(filename + '.optim'):
+            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+    
+        if optimizer_state_dict is not None:
+            optimizer.load_state_dict(optimizer_state_dict)
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")
+
+    scaler = torch.cuda.amp.GradScaler()
+
+    batch_size = ds.batch_size
+    gradient_step = ds.gradient_step
+    # n steps = batch_size * gradient_step * n image processed
+    steps_per_epoch = len(ds) // batch_size // gradient_step
+    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+    loss_step = 0
+    _loss_step = 0 #internal
 
     last_saved_file = "<none>"
     last_saved_image = "<none>"
+    forced_filename = "<none>"
     embedding_yet_to_be_embedded = False
 
-    ititial_step = embedding.step or 0
-    if ititial_step > steps:
-        return embedding, filename
+    is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+    img_c = None
 
-    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
+    pbar = tqdm.tqdm(total=steps - initial_step)
+    try:
+        for i in range((steps-initial_step) * gradient_step):
+            if scheduler.finished:
+                break
+            if shared.state.interrupted:
+                break
+            for j, batch in enumerate(dl):
+                # works as a drop_last=True for gradient accumulation
+                if j == max_steps_per_epoch:
+                    break
+                scheduler.apply(optimizer, embedding.step)
+                if scheduler.finished:
+                    break
+                if shared.state.interrupted:
+                    break
 
-    pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
-    for i, entries in pbar:
-        embedding.step = i + ititial_step
+                with devices.autocast():
+                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                    c = shared.sd_model.cond_stage_model(batch.cond_text)
 
-        scheduler.apply(optimizer, embedding.step)
-        if scheduler.finished:
-            break
+                    if is_training_inpainting_model:
+                        if img_c is None:
+                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
 
-        if shared.state.interrupted:
-            break
+                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
+                    else:
+                        cond = c
 
-        with torch.autocast("cuda"):
-            c = cond_model([entry.cond_text for entry in entries])
-            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
-            loss = shared.sd_model(x, c)[0]
-            del x
+                    loss = shared.sd_model(x, cond)[0] / gradient_step
+                    del x
 
-            losses[embedding.step % losses.shape[0]] = loss.item()
+                    _loss_step += loss.item()
+                scaler.scale(loss).backward()
 
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+                # go back until we reach gradient accumulation steps
+                if (j + 1) % gradient_step != 0:
+                    continue
+                scaler.step(optimizer)
+                scaler.update()
+                embedding.step += 1
+                pbar.update()
+                optimizer.zero_grad(set_to_none=True)
+                loss_step = _loss_step
+                _loss_step = 0
 
-        epoch_num = embedding.step // len(ds)
-        epoch_step = embedding.step - (epoch_num * len(ds)) + 1
+                steps_done = embedding.step + 1
 
-        pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
+                epoch_num = embedding.step // steps_per_epoch
+                epoch_step = embedding.step % steps_per_epoch
 
-        if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
-            last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
-            embedding.save(last_saved_file)
-            embedding_yet_to_be_embedded = True
+                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+                if embedding_dir is not None and steps_done % save_embedding_every == 0:
+                    # Before saving, change name to match current checkpoint.
+                    embedding_name_every = f'{embedding_name}-{steps_done}'
+                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+                    embedding_yet_to_be_embedded = True
 
-        write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
-            "loss": f"{losses.mean():.7f}",
-            "learn_rate": scheduler.learn_rate
-        })
+                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
+                    "loss": f"{loss_step:.7f}",
+                    "learn_rate": scheduler.learn_rate
+                })
 
-        if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
-            last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+                if images_dir is not None and steps_done % create_image_every == 0:
+                    forced_filename = f'{embedding_name}-{steps_done}'
+                    last_saved_image = os.path.join(images_dir, forced_filename)
 
-            p = processing.StableDiffusionProcessingTxt2Img(
-                sd_model=shared.sd_model,
-                do_not_save_grid=True,
-                do_not_save_samples=True,
-                do_not_reload_embeddings=True,
-            )
+                    shared.sd_model.first_stage_model.to(devices.device)
 
-            if preview_from_txt2img:
-                p.prompt = preview_prompt
-                p.negative_prompt = preview_negative_prompt
-                p.steps = preview_steps
-                p.sampler_index = preview_sampler_index
-                p.cfg_scale = preview_cfg_scale
-                p.seed = preview_seed
-                p.width = preview_width
-                p.height = preview_height
-            else:
-                p.prompt = entries[0].cond_text
-                p.steps = 20
-                p.width = training_width
-                p.height = training_height
+                    p = processing.StableDiffusionProcessingTxt2Img(
+                        sd_model=shared.sd_model,
+                        do_not_save_grid=True,
+                        do_not_save_samples=True,
+                        do_not_reload_embeddings=True,
+                    )
 
-            preview_text = p.prompt
+                    if preview_from_txt2img:
+                        p.prompt = preview_prompt
+                        p.negative_prompt = preview_negative_prompt
+                        p.steps = preview_steps
+                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+                        p.cfg_scale = preview_cfg_scale
+                        p.seed = preview_seed
+                        p.width = preview_width
+                        p.height = preview_height
+                    else:
+                        p.prompt = batch.cond_text[0]
+                        p.steps = 20
+                        p.width = training_width
+                        p.height = training_height
 
-            processed = processing.process_images(p)
-            image = processed.images[0]
+                    preview_text = p.prompt
 
-            shared.state.current_image = image
+                    processed = processing.process_images(p)
+                    image = processed.images[0] if len(processed.images) > 0 else None
 
-            if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+                    if unload:
+                        shared.sd_model.first_stage_model.to(devices.cpu)
 
-                last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+                    if image is not None:
+                        shared.state.current_image = image
+                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                        last_saved_image += f", prompt: {preview_text}"
 
-                info = PngImagePlugin.PngInfo()
-                data = torch.load(last_saved_file)
-                info.add_text("sd-ti-embedding", embedding_to_b64(data))
+                    if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
 
-                title = "<{}>".format(data.get('name', '???'))
+                        last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
 
-                try:
-                    vectorSize = list(data['string_to_param'].values())[0].shape[0]
-                except Exception as e:
-                    vectorSize = '?'
+                        info = PngImagePlugin.PngInfo()
+                        data = torch.load(last_saved_file)
+                        info.add_text("sd-ti-embedding", embedding_to_b64(data))
 
-                checkpoint = sd_models.select_checkpoint()
-                footer_left = checkpoint.model_name
-                footer_mid = '[{}]'.format(checkpoint.hash)
-                footer_right = '{}v {}s'.format(vectorSize, embedding.step)
+                        title = "<{}>".format(data.get('name', '???'))
 
-                captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
-                captioned_image = insert_image_data_embed(captioned_image, data)
+                        try:
+                            vectorSize = list(data['string_to_param'].values())[0].shape[0]
+                        except Exception as e:
+                            vectorSize = '?'
 
-                captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
-                embedding_yet_to_be_embedded = False
+                        checkpoint = sd_models.select_checkpoint()
+                        footer_left = checkpoint.model_name
+                        footer_mid = '[{}]'.format(checkpoint.hash)
+                        footer_right = '{}v {}s'.format(vectorSize, steps_done)
 
-            image.save(last_saved_image)
+                        captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+                        captioned_image = insert_image_data_embed(captioned_image, data)
 
-            last_saved_image += f", prompt: {preview_text}"
+                        captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+                        embedding_yet_to_be_embedded = False
 
-        shared.state.job_no = embedding.step
+                    last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                    last_saved_image += f", prompt: {preview_text}"
 
-        shared.state.textinfo = f"""
+                shared.state.job_no = embedding.step
+
+                shared.state.textinfo = f"""
 <p>
-Loss: {losses.mean():.7f}<br/>
-Step: {embedding.step}<br/>
-Last prompt: {html.escape(entries[0].cond_text)}<br/>
+Loss: {loss_step:.7f}<br/>
+Step: {steps_done}<br/>
+Last prompt: {html.escape(batch.cond_text[0])}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
-
-    checkpoint = sd_models.select_checkpoint()
-
-    embedding.sd_checkpoint = checkpoint.hash
-    embedding.sd_checkpoint_name = checkpoint.model_name
-    embedding.cached_checksum = None
-    embedding.save(filename)
+        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+    except Exception:
+        print(traceback.format_exc(), file=sys.stderr)
+        pass
+    finally:
+        pbar.leave = False
+        pbar.close()
+        shared.sd_model.first_stage_model.to(devices.device)
+        shared.parallel_processing_allowed = old_parallel_processing_allowed
 
     return embedding, filename
+
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+    old_embedding_name = embedding.name
+    old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+    old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+    old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+    try:
+        embedding.sd_checkpoint = checkpoint.hash
+        embedding.sd_checkpoint_name = checkpoint.model_name
+        if remove_cached_checksum:
+            embedding.cached_checksum = None
+        embedding.name = embedding_name
+        embedding.optimizer_state_dict = optimizer.state_dict()
+        embedding.save(filename)
+    except:
+        embedding.sd_checkpoint = old_sd_checkpoint
+        embedding.sd_checkpoint_name = old_sd_checkpoint_name
+        embedding.name = old_embedding_name
+        embedding.cached_checksum = old_cached_checksum
+        raise
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 36881e7a..35c4feef 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
 from modules import sd_hijack, shared
 
 
-def create_embedding(name, initialization_text, nvpt):
-    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
 
@@ -18,15 +18,17 @@ def create_embedding(name, initialization_text, nvpt):
 def preprocess(*args):
     modules.textual_inversion.preprocess.preprocess(*args)
 
-    return "Preprocessing finished.", ""
+    return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
 
 
 def train_embedding(*args):
 
     assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
 
+    apply_optimizations = shared.opts.training_xattention_optimizations
     try:
-        sd_hijack.undo_optimizations()
+        if not apply_optimizations:
+            sd_hijack.undo_optimizations()
 
         embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
 
@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
     except Exception:
         raise
     finally:
-        sd_hijack.apply_optimizations()
+        if not apply_optimizations:
+            sd_hijack.apply_optimizations()
 
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 2381347f..e189a899 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,12 +1,14 @@
 import modules.scripts
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules import sd_samplers
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
+    StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, cmd_opts
 import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html
 
 
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
     p = StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -20,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         seed_resize_from_h=seed_resize_from_h,
         seed_resize_from_w=seed_resize_from_w,
         seed_enable_extras=seed_enable_extras,
-        sampler_index=sampler_index,
+        sampler_name=sd_samplers.samplers[sampler_index].name,
         batch_size=batch_size,
         n_iter=n_iter,
         steps=steps,
@@ -31,10 +33,13 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         tiling=tiling,
         enable_hr=enable_hr,
         denoising_strength=denoising_strength if enable_hr else None,
-        firstphase_width=firstphase_width if enable_hr else None,
-        firstphase_height=firstphase_height if enable_hr else None,
+        hr_scale=hr_scale,
+        hr_upscaler=hr_upscaler,
     )
 
+    p.scripts = modules.scripts.scripts_txt2img
+    p.script_args = args
+
     if cmd_opts.enable_console_prompts:
         print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
 
@@ -43,6 +48,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
     if processed is None:
         processed = process_images(p)
 
+    p.close()
+
     shared.total_tqdm.clear()
 
     generation_info_js = processed.js()
@@ -52,5 +59,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
     if opts.do_not_show_images:
         processed.images = []
 
-    return processed.images, generation_info_js, plaintext_to_html(processed.info)
-
+    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/ui.py b/modules/ui.py
index 0abd177a..e4859020 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1,62 +1,63 @@
-import base64
 import html
-import io
 import json
 import math
 import mimetypes
 import os
+import platform
 import random
+import subprocess as sp
 import sys
 import tempfile
 import time
 import traceback
-import platform
-import subprocess as sp
 from functools import partial, reduce
 
-import numpy as np
-import torch
-from PIL import Image, PngImagePlugin
-import piexif
-
 import gradio as gr
-import gradio.utils
 import gradio.routes
+import gradio.utils
+import numpy as np
+from PIL import Image, PngImagePlugin
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
-from modules import sd_hijack, sd_models, localization
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
+from modules.ui_components import FormRow, FormGroup, ToolButton
 from modules.paths import script_path
+
 from modules.shared import opts, cmd_opts, restricted_opts
-if cmd_opts.deepdanbooru:
-    from modules.deepbooru import get_deepbooru_tags
-import modules.shared as shared
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.sd_hijack import model_hijack
-import modules.ldsr_model
-import modules.scripts
-import modules.gfpgan_model
+
 import modules.codeformer_model
+import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
+import modules.scripts
+import modules.shared as shared
 import modules.styles
-import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
 from modules import prompt_parser
 from modules.images import save_image
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
 import modules.textual_inversion.ui
 import modules.hypernetworks.ui
-import modules.images_history as img_his
+from modules.generation_parameters_copypaste import image_from_url_text
 
 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
 mimetypes.init()
 mimetypes.add_type('application/javascript', '.js')
 
-
 if not cmd_opts.share and not cmd_opts.listen:
     # fix gradio phoning home
     gradio.utils.version_check = lambda: None
     gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
 
-if cmd_opts.ngrok != None:
+if cmd_opts.ngrok is not None:
     import modules.ngrok as ngrok
     print('ngrok authtoken detected, trying to connect...')
-    ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region)
+    ngrok.connect(
+        cmd_opts.ngrok,
+        cmd_opts.port if cmd_opts.port is not None else 7860,
+        cmd_opts.ngrok_region
+        )
 
 
 def gr_show(visible=True):
@@ -69,57 +70,34 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
 css_hide_progressbar = """
 .wrap .m-12 svg { display:none!important; }
 .wrap .m-12::before { content:"Loading..." }
+.wrap .z-20 svg { display:none!important; }
+.wrap .z-20::before { content:"Loading..." }
 .progress-bar { display:none!important; }
 .meta-text { display:none!important; }
+.meta-text-center { display:none!important; }
 """
 
 # Using constants for these since the variation selector isn't visible.
 # Important that they exactly match script.js for tooltip to work.
 random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
 reuse_symbol = '\u267b\ufe0f'  # ♻️
-art_symbol = '\U0001f3a8'  # 🎨
 paste_symbol = '\u2199\ufe0f'  # ↙
 folder_symbol = '\U0001f4c2'  # 📂
 refresh_symbol = '\U0001f504'  # 🔄
 save_style_symbol = '\U0001f4be'  # 💾
 apply_style_symbol = '\U0001f4cb'  # 📋
+clear_prompt_symbol = '\U0001F5D1'  # 🗑️
 
 
 def plaintext_to_html(text):
     text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
     return text
 
-
-def image_from_url_text(filedata):
-    if type(filedata) == dict and filedata["is_file"]:
-        filename = filedata["name"]
-        tempdir = os.path.normpath(tempfile.gettempdir())
-        normfn = os.path.normpath(filename)
-        assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
-
-        return Image.open(filename)
-
-    if type(filedata) == list:
-        if len(filedata) == 0:
-            return None
-
-        filedata = filedata[0]
-
-    if filedata.startswith("data:image/png;base64,"):
-        filedata = filedata[len("data:image/png;base64,"):]
-
-    filedata = base64.decodebytes(filedata.encode('utf-8'))
-    image = Image.open(io.BytesIO(filedata))
-    return image
-
-
 def send_gradio_gallery_to_image(x):
     if len(x) == 0:
         return None
-
     return image_from_url_text(x[0])
 
-
 def save_files(js_data, images, do_make_zip, index):
     import csv
     filenames = []
@@ -168,7 +146,7 @@ def save_files(js_data, images, do_make_zip, index):
                 filenames.append(os.path.basename(txt_fullfn))
                 fullfns.append(txt_fullfn)
 
-        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
+        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
 
     # Make Zip
     if do_make_zip:
@@ -181,84 +159,9 @@ def save_files(js_data, images, do_make_zip, index):
                     zip_file.writestr(filenames[i], f.read())
         fullfns.insert(0, zip_filepath)
 
-    return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
+    return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
 
 
-def save_pil_to_file(pil_image, dir=None):
-    use_metadata = False
-    metadata = PngImagePlugin.PngInfo()
-    for key, value in pil_image.info.items():
-        if isinstance(key, str) and isinstance(value, str):
-            metadata.add_text(key, value)
-            use_metadata = True
-
-    file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
-    pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
-    return file_obj
-
-
-# override save to file function so that it also writes PNG info
-gr.processing_utils.save_pil_to_file = save_pil_to_file
-
-
-def wrap_gradio_call(func, extra_outputs=None):
-    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
-        run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
-        if run_memmon:
-            shared.mem_mon.monitor()
-        t = time.perf_counter()
-
-        try:
-            res = list(func(*args, **kwargs))
-        except Exception as e:
-            # When printing out our debug argument list, do not print out more than a MB of text
-            max_debug_str_len = 131072 # (1024*1024)/8
-
-            print("Error completing request", file=sys.stderr)
-            argStr = f"Arguments: {str(args)} {str(kwargs)}"
-            print(argStr[:max_debug_str_len], file=sys.stderr)
-            if len(argStr) > max_debug_str_len:
-                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-
-            print(traceback.format_exc(), file=sys.stderr)
-
-            shared.state.job = ""
-            shared.state.job_count = 0
-
-            if extra_outputs_array is None:
-                extra_outputs_array = [None, '']
-
-            res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
-
-        elapsed = time.perf_counter() - t
-        elapsed_m = int(elapsed // 60)
-        elapsed_s = elapsed % 60
-        elapsed_text = f"{elapsed_s:.2f}s"
-        if (elapsed_m > 0):
-            elapsed_text = f"{elapsed_m}m "+elapsed_text
-
-        if run_memmon:
-            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
-            active_peak = mem_stats['active_peak']
-            reserved_peak = mem_stats['reserved_peak']
-            sys_peak = mem_stats['system_peak']
-            sys_total = mem_stats['total']
-            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
-
-            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
-        else:
-            vram_html = ''
-
-        # last item is always HTML
-        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
-
-        shared.state.skipped = False
-        shared.state.interrupted = False
-        shared.state.job_count = 0
-
-        return tuple(res)
-
-    return f
 
 
 def calc_time_left(progress, threshold, label, force_display, showTime):
@@ -306,13 +209,8 @@ def check_progress_call(id_part):
     image = gr_show(False)
     preview_visibility = gr_show(False)
 
-    if opts.show_progress_every_n_steps > 0:
-        if shared.parallel_processing_allowed:
-
-            if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
-                shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
-                shared.state.current_image_sampling_step = shared.state.sampling_step
-
+    if opts.show_progress_every_n_steps != 0:
+        shared.state.set_current_image()
         image = shared.state.current_image
 
         if image is None:
@@ -339,13 +237,6 @@ def check_progress_call_initial(id_part):
     return check_progress_call(id_part)
 
 
-def roll_artist(prompt):
-    allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories])
-    artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
-
-    return prompt + ", " + artist.name if prompt != '' else artist.name
-
-
 def visit(x, func, path=""):
     if hasattr(x, 'children'):
         for c in x.children:
@@ -375,45 +266,41 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):
 
 
 def interrogate(image):
-    prompt = shared.interrogator.interrogate(image)
+    prompt = shared.interrogator.interrogate(image.convert("RGB"))
 
     return gr_show(True) if prompt is None else prompt
 
 
 def interrogate_deepbooru(image):
-    prompt = get_deepbooru_tags(image)
+    prompt = deepbooru.model.tag(image)
     return gr_show(True) if prompt is None else prompt
 
 
-def create_seed_inputs():
-    with gr.Row():
-        with gr.Box():
-            with gr.Row(elem_id='seed_row'):
-                seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1)
-                seed.style(container=False)
-                random_seed = gr.Button(random_symbol, elem_id='random_seed')
-                reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed')
+def create_seed_inputs(target_interface):
+    with FormRow(elem_id=target_interface + '_seed_row'):
+        seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
+        seed.style(container=False)
+        random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
+        reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
 
-        with gr.Box(elem_id='subseed_show_box'):
-            seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False)
+        with gr.Group(elem_id=target_interface + '_subseed_show_box'):
+            seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
 
     # Components to show/hide based on the 'Extra' checkbox
     seed_extras = []
 
-    with gr.Row(visible=False) as seed_extra_row_1:
+    with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
         seed_extras.append(seed_extra_row_1)
-        with gr.Box():
-            with gr.Row(elem_id='subseed_row'):
-                subseed = gr.Number(label='Variation seed', value=-1)
-                subseed.style(container=False)
-                random_subseed = gr.Button(random_symbol, elem_id='random_subseed')
-                reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed')
-        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01)
+        subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
+        subseed.style(container=False)
+        random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
+        reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
 
-    with gr.Row(visible=False) as seed_extra_row_2:
+    with FormRow(visible=False) as seed_extra_row_2:
         seed_extras.append(seed_extra_row_2)
-        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0)
-        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0)
+        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
+        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
 
     random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
     random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@@ -426,6 +313,17 @@ def create_seed_inputs():
     return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
 
 
+
+def connect_clear_prompt(button):
+    """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
+    button.click(
+        _js="clear_prompt",
+        fn=None,
+        inputs=[],
+        outputs=[],
+    )
+
+
 def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
     """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
         (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
@@ -497,22 +395,26 @@ def create_toprow(is_img2img):
                         )
 
         with gr.Column(scale=1, elem_id="roll_col"):
-            roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
             paste = gr.Button(value=paste_symbol, elem_id="paste")
             save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
             prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
-
+            clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
             token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
             token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
 
+            clear_prompt_button.click(
+                fn=lambda *x: x,
+                _js="confirm_clear_prompt",
+                inputs=[prompt, negative_prompt],
+                outputs=[prompt, negative_prompt],
+            )
+
         button_interrogate = None
         button_deepbooru = None
         if is_img2img:
             with gr.Column(scale=1, elem_id="interrogate_col"):
                 button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
-
-                if cmd_opts.deepdanbooru:
-                    button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
 
         with gr.Column(scale=1):
             with gr.Row():
@@ -541,7 +443,7 @@ def create_toprow(is_img2img):
                     prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
                     prompt_style2.save_to_config = True
 
-    return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+    return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
 
 
 def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -569,6 +471,9 @@ def apply_setting(key, value):
     if value is None:
         return gr.update()
 
+    if shared.cmd_opts.freeze_settings:
+        return gr.update()
+
     # dont allow model to be swapped when model hash exists in prompt
     if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
         return gr.update()
@@ -586,7 +491,7 @@ def apply_setting(key, value):
         return
 
     valtype = type(opts.data_labels[key].default)
-    oldval = opts.data[key]
+    oldval = opts.data.get(key, None)
     opts.data[key] = valtype(value) if valtype != type(None) else value
     if oldval != value and opts.data_labels[key].onchange is not None:
         opts.data_labels[key].onchange()
@@ -595,30 +500,175 @@ def apply_setting(key, value):
     return value
 
 
-def create_ui(wrap_gradio_gpu_call):
+def update_generation_info(args):
+    generation_info, html_info, img_index = args
+    try:
+        generation_info = json.loads(generation_info)
+        if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+            return html_info
+        return plaintext_to_html(generation_info["infotexts"][img_index])
+    except Exception:
+        pass
+    # if the json parse or anything else fails, just return the old html_info
+    return html_info
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+    def refresh():
+        refresh_method()
+        args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+        for k, v in args.items():
+            setattr(refresh_component, k, v)
+
+        return gr.update(**(args or {}))
+
+    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+    refresh_button.click(
+        fn=refresh,
+        inputs=[],
+        outputs=[refresh_component]
+    )
+    return refresh_button
+
+
+def create_output_panel(tabname, outdir):
+    def open_folder(f):
+        if not os.path.exists(f):
+            print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+            return
+        elif not os.path.isdir(f):
+            print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+            return
+
+        if not shared.cmd_opts.hide_ui_dir_config:
+            path = os.path.normpath(f)
+            if platform.system() == "Windows":
+                os.startfile(path)
+            elif platform.system() == "Darwin":
+                sp.Popen(["open", path])
+            else:
+                sp.Popen(["xdg-open", path])
+
+    with gr.Column(variant='panel'):
+            with gr.Group():
+                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
+            generation_info = None
+            with gr.Column():
+                with gr.Row(elem_id=f"image_buttons_{tabname}"):
+                    open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder')
+
+                    if tabname != "extras":
+                        save = gr.Button('Save', elem_id=f'save_{tabname}')
+                        save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
+
+                    buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+
+                open_folder_button.click(
+                    fn=lambda: open_folder(opts.outdir_samples or outdir),
+                    inputs=[],
+                    outputs=[],
+                )
+
+                if tabname != "extras":
+                    with gr.Row():
+                        download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+
+                    with gr.Group():
+                        html_info = gr.HTML()
+                        html_log = gr.HTML()
+
+                        generation_info = gr.Textbox(visible=False)
+                        if tabname == 'txt2img' or tabname == 'img2img':
+                            generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+                            generation_info_button.click(
+                                fn=update_generation_info,
+                                _js="(x, y) => [x, y, selected_gallery_index()]",
+                                inputs=[generation_info, html_info],
+                                outputs=[html_info],
+                                preprocess=False
+                            )
+
+                        save.click(
+                            fn=wrap_gradio_call(save_files),
+                            _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
+                            inputs=[
+                                generation_info,
+                                result_gallery,
+                                html_info,
+                                html_info,
+                            ],
+                            outputs=[
+                                download_files,
+                                html_log,
+                            ]
+                        )
+
+                        save_zip.click(
+                            fn=wrap_gradio_call(save_files),
+                            _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+                            inputs=[
+                                generation_info,
+                                result_gallery,
+                                html_info,
+                                html_info,
+                            ],
+                            outputs=[
+                                download_files,
+                                html_log,
+                            ]
+                        )
+
+                else:
+                    html_info_x = gr.HTML()
+                    html_info = gr.HTML()
+                    html_log = gr.HTML()
+
+                parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+                return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+
+
+def create_sampler_and_steps_selection(choices, tabname):
+    if opts.samplers_in_dropdown:
+        with FormRow(elem_id=f"sampler_selection_{tabname}"):
+            sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+            sampler_index.save_to_config = True
+            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
+    else:
+        with FormGroup(elem_id=f"sampler_selection_{tabname}"):
+            steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
+            sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+
+    return steps, sampler_index
+
+
+def ordered_ui_categories():
+    user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+
+    for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
+        yield category
+
+
+def create_ui():
     import modules.img2img
     import modules.txt2img
 
-    def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
-        def refresh():
-            refresh_method()
-            args = refreshed_args() if callable(refreshed_args) else refreshed_args
+    reload_javascript()
 
-            for k, v in args.items():
-                setattr(refresh_component, k, v)
+    parameters_copypaste.reset()
 
-            return gr.update(**(args or {}))
-
-        refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
-        refresh_button.click(
-            fn = refresh,
-            inputs = [],
-            outputs = [refresh_component]
-        )
-        return refresh_button
+    modules.scripts.scripts_current = modules.scripts.scripts_txt2img
+    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
 
     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
-        txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+        txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+
         dummy_component = gr.Label(visible=False)
         txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
 
@@ -632,65 +682,58 @@ def create_ui(wrap_gradio_gpu_call):
                 setup_progressbar(progressbar, txt2img_preview, 'txt2img')
 
         with gr.Row().style(equal_height=False):
-            with gr.Column(variant='panel'):
-                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
-                sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
+            with gr.Column(variant='panel', elem_id="txt2img_settings"):
+                for category in ordered_ui_categories():
+                    if category == "sampler":
+                        steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
 
-                with gr.Group():
-                    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
-                    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+                    elif category == "dimensions":
+                        with FormRow():
+                            with gr.Column(elem_id="txt2img_column_size", scale=4):
+                                width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+                                height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
 
-                with gr.Row():
-                    restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
-                    tiling = gr.Checkbox(label='Tiling', value=False)
-                    enable_hr = gr.Checkbox(label='Highres. fix', value=False)
+                            if opts.dimensions_and_batch_together:
+                                with gr.Column(elem_id="txt2img_column_batch"):
+                                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+                                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
 
-                with gr.Row(visible=False) as hr_options:
-                    firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
-                    firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
-                    denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
+                    elif category == "cfg":
+                        cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
 
-                with gr.Row(equal_height=True):
-                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
-                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
+                    elif category == "seed":
+                        seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
 
-                cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
+                    elif category == "checkboxes":
+                        with FormRow(elem_id="txt2img_checkboxes"):
+                            restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
+                            tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
+                            enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
 
-                seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
+                    elif category == "hires_fix":
+                        with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options:
+                            hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+                            hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+                            denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
 
-                with gr.Group():
-                    custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
+                    elif category == "batch":
+                        if not opts.dimensions_and_batch_together:
+                            with FormRow(elem_id="txt2img_column_batch"):
+                                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+                                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
 
-            with gr.Column(variant='panel'):
+                    elif category == "scripts":
+                        with FormGroup(elem_id="txt2img_script_container"):
+                            custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
 
-                with gr.Group():
-                    txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
-                    txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
-
-                with gr.Column():
-                    with gr.Row():
-                        save = gr.Button('Save')
-                        send_to_img2img = gr.Button('Send to img2img')
-                        send_to_inpaint = gr.Button('Send to inpaint')
-                        send_to_extras = gr.Button('Send to extras')
-                        button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
-                        open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
-
-                    with gr.Row():
-                        do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
-                    with gr.Row():
-                        download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
-
-                    with gr.Group():
-                        html_info = gr.HTML()
-                        generation_info = gr.Textbox(visible=False)
+            txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+            parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
 
             connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
             connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
 
             txt2img_args = dict(
-                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
+                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
                 _js="submit",
                 inputs=[
                     txt2img_prompt,
@@ -710,13 +753,15 @@ def create_ui(wrap_gradio_gpu_call):
                     width,
                     enable_hr,
                     denoising_strength,
-                    firstphase_width,
-                    firstphase_height,
+                    hr_scale,
+                    hr_upscaler,
                 ] + custom_inputs,
+
                 outputs=[
                     txt2img_gallery,
                     generation_info,
-                    html_info
+                    html_info,
+                    html_log,
                 ],
                 show_progress=False,
             )
@@ -741,34 +786,6 @@ def create_ui(wrap_gradio_gpu_call):
                 outputs=[hr_options],
             )
 
-            save.click(
-                fn=wrap_gradio_call(save_files),
-                _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
-                inputs=[
-                    generation_info,
-                    txt2img_gallery,
-                    do_make_zip,
-                    html_info,
-                ],
-                outputs=[
-                    download_files,
-                    html_info,
-                    html_info,
-                    html_info,
-                ]
-            )
-
-            roll.click(
-                fn=roll_artist,
-                _js="update_txt2img_tokens",
-                inputs=[
-                    txt2img_prompt,
-                ],
-                outputs=[
-                    txt2img_prompt,
-                ]
-            )
-
             txt2img_paste_fields = [
                 (txt2img_prompt, "Prompt"),
                 (txt2img_negative_prompt, "Negative prompt"),
@@ -787,9 +804,11 @@ def create_ui(wrap_gradio_gpu_call):
                 (denoising_strength, "Denoising strength"),
                 (enable_hr, lambda d: "Denoising strength" in d),
                 (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
-                (firstphase_width, "First pass size-1"),
-                (firstphase_height, "First pass size-2"),
+                (hr_scale, "Hires upscale"),
+                (hr_upscaler, "Hires upscaler"),
+                *modules.scripts.scripts_txt2img.infotext_fields
             ]
+            parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
 
             txt2img_preview_params = [
                 txt2img_prompt,
@@ -802,10 +821,13 @@ def create_ui(wrap_gradio_gpu_call):
                 height,
             ]
 
-            token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
+            token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
+
+    modules.scripts.scripts_current = modules.scripts.scripts_img2img
+    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
 
     with gr.Blocks(analytics_enabled=False) as img2img_interface:
-        img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
+        img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
 
         with gr.Row(elem_id='img2img_progress_row'):
             img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -818,88 +840,98 @@ def create_ui(wrap_gradio_gpu_call):
                 img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
                 setup_progressbar(progressbar, img2img_preview, 'img2img')
 
-        with gr.Row().style(equal_height=False):
-            with gr.Column(variant='panel'):
+        with FormRow().style(equal_height=False):
+            with gr.Column(variant='panel', elem_id="img2img_settings"):
 
                 with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
-                    with gr.TabItem('img2img', id='img2img'):
-                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
+                    with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
+                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
 
-                    with gr.TabItem('Inpaint', id='inpaint'):
-                        init_img_with_mask = gr.Image(label="Image for inpainting with mask",  show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+                    with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
+                        init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
+                        init_img_with_mask_orig = gr.State(None)
+
+                        use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
+                        if use_color_sketch:
+                            def update_orig(image, state):
+                                if image is not None:
+                                    same_size = state is not None and state.size == image.size
+                                    has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+                                    edited = same_size and has_exact_match
+                                    return image if not edited or state is None else state
+
+                            init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
 
                         init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
                         init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
 
-                        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
+                        with FormRow():
+                            mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+                            mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
 
-                        with gr.Row():
-                            mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
-                            inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
+                        with FormRow():
+                            mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
+                            inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
 
-                        inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
+                        with FormRow():
+                            inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
 
-                        with gr.Row():
-                            inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
-                            inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)
+                        with FormRow():
+                            with gr.Column():
+                                inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
 
-                    with gr.TabItem('Batch img2img', id='batch'):
+                            with gr.Column(scale=4):
+                                inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+                    with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
                         hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
                         gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
-                        img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
-                        img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
+                        img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
+                        img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
 
-                with gr.Row():
-                    resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
+                with FormRow():
+                    resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
 
-                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
-                sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
+                for category in ordered_ui_categories():
+                    if category == "sampler":
+                        steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
 
-                with gr.Group():
-                    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
-                    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+                    elif category == "dimensions":
+                        with FormRow():
+                            with gr.Column(elem_id="img2img_column_size", scale=4):
+                                width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+                                height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
 
-                with gr.Row():
-                    restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
-                    tiling = gr.Checkbox(label='Tiling', value=False)
+                            if opts.dimensions_and_batch_together:
+                                with gr.Column(elem_id="img2img_column_batch"):
+                                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+                                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
 
-                with gr.Row():
-                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
-                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
+                    elif category == "cfg":
+                        with FormGroup():
+                            cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+                            denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
 
-                with gr.Group():
-                    cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
-                    denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
+                    elif category == "seed":
+                        seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
 
-                seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
+                    elif category == "checkboxes":
+                        with FormRow(elem_id="img2img_checkboxes"):
+                            restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
+                            tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
 
-                with gr.Group():
-                    custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
+                    elif category == "batch":
+                        if not opts.dimensions_and_batch_together:
+                            with FormRow(elem_id="img2img_column_batch"):
+                                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+                                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
 
-            with gr.Column(variant='panel'):
+                    elif category == "scripts":
+                        with FormGroup(elem_id="img2img_script_container"):
+                            custom_inputs = modules.scripts.scripts_img2img.setup_ui()
 
-                with gr.Group():
-                    img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
-                    img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
-
-                with gr.Column():
-                    with gr.Row():
-                        save = gr.Button('Save')
-                        img2img_send_to_img2img = gr.Button('Send to img2img')
-                        img2img_send_to_inpaint = gr.Button('Send to inpaint')
-                        img2img_send_to_extras = gr.Button('Send to extras')
-                        button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
-                        open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
-
-                    with gr.Row():
-                        do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
-                    with gr.Row():
-                        download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
-
-                    with gr.Group():
-                        html_info = gr.HTML()
-                        generation_info = gr.Textbox(visible=False)
+            img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
+            parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
 
             connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
             connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -930,7 +962,7 @@ def create_ui(wrap_gradio_gpu_call):
             )
 
             img2img_args = dict(
-                fn=wrap_gradio_gpu_call(modules.img2img.img2img),
+                fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
                 _js="submit_img2img",
                 inputs=[
                     dummy_component,
@@ -940,12 +972,14 @@ def create_ui(wrap_gradio_gpu_call):
                     img2img_prompt_style2,
                     init_img,
                     init_img_with_mask,
+                    init_img_with_mask_orig,
                     init_img_inpaint,
                     init_mask_inpaint,
                     mask_mode,
                     steps,
                     sampler_index,
                     mask_blur,
+                    mask_alpha,
                     inpainting_fill,
                     restore_faces,
                     tiling,
@@ -967,7 +1001,8 @@ def create_ui(wrap_gradio_gpu_call):
                 outputs=[
                     img2img_gallery,
                     generation_info,
-                    html_info
+                    html_info,
+                    html_log,
                 ],
                 show_progress=False,
             )
@@ -981,39 +1016,10 @@ def create_ui(wrap_gradio_gpu_call):
                 outputs=[img2img_prompt],
             )
 
-            if cmd_opts.deepdanbooru:
-                img2img_deepbooru.click(
-                    fn=interrogate_deepbooru,
-                    inputs=[init_img],
-                    outputs=[img2img_prompt],
-                )
-
-            save.click(
-                fn=wrap_gradio_call(save_files),
-                _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
-                inputs=[
-                    generation_info,
-                    img2img_gallery,
-                    do_make_zip,
-                    html_info,
-                ],
-                outputs=[
-                    download_files,
-                    html_info,
-                    html_info,
-                    html_info,
-                ]
-            )
-
-            roll.click(
-                fn=roll_artist,
-                _js="update_img2img_tokens",
-                inputs=[
-                    img2img_prompt,
-                ],
-                outputs=[
-                    img2img_prompt,
-                ]
+            img2img_deepbooru.click(
+                fn=interrogate_deepbooru,
+                inputs=[init_img],
+                outputs=[img2img_prompt],
             )
 
             prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
@@ -1038,6 +1044,8 @@ def create_ui(wrap_gradio_gpu_call):
                     outputs=[prompt, negative_prompt, style1, style2],
                 )
 
+            token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+
             img2img_paste_fields = [
                 (img2img_prompt, "Prompt"),
                 (img2img_negative_prompt, "Negative prompt"),
@@ -1054,66 +1062,62 @@ def create_ui(wrap_gradio_gpu_call):
                 (seed_resize_from_w, "Seed resize from-1"),
                 (seed_resize_from_h, "Seed resize from-2"),
                 (denoising_strength, "Denoising strength"),
+                (mask_blur, "Mask blur"),
+                *modules.scripts.scripts_img2img.infotext_fields
             ]
-            token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+            parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
+            parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
+
+    modules.scripts.scripts_current = None
 
     with gr.Blocks(analytics_enabled=False) as extras_interface:
         with gr.Row().style(equal_height=False):
             with gr.Column(variant='panel'):
                 with gr.Tabs(elem_id="mode_extras"):
-                    with gr.TabItem('Single Image'):
-                        extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
+                    with gr.TabItem('Single Image', elem_id="extras_single_tab"):
+                        extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
 
-                    with gr.TabItem('Batch Process'):
-                        image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
+                    with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
+                        image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
 
-                    with gr.TabItem('Batch from Directory'):
-                        extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
-                            placeholder="A directory on the same machine where the server is running."
-                        )
-                        extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
-                            placeholder="Leave blank to save images to the default path."
-                        )
-                        show_extras_results = gr.Checkbox(label='Show result images', value=True)
+                    with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
+                        extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+                        extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+                        show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
+
+                submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
 
                 with gr.Tabs(elem_id="extras_resize_mode"):
-                    with gr.TabItem('Scale by'):
-                        upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
-                    with gr.TabItem('Scale to'):
+                    with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
+                        upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+                    with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
                         with gr.Group():
                             with gr.Row():
-                                upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
-                                upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
-                            upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
+                                upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
+                                upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
+                            upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
 
                 with gr.Group():
                     extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
 
                 with gr.Group():
                     extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
-                    extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
+                    extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
 
                 with gr.Group():
-                    gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan)
+                    gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
 
                 with gr.Group():
-                    codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
-                    codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
+                    codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
+                    codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
 
-                submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
-            with gr.Column(variant='panel'):
-                result_images = gr.Gallery(label="Result", show_label=False)
-                html_info_x = gr.HTML()
-                html_info = gr.HTML()
-                extras_send_to_img2img = gr.Button('Send to img2img')
-                extras_send_to_inpaint = gr.Button('Send to inpaint')
-                button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
-                open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
+                with gr.Group():
+                    upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
 
+            result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
 
         submit.click(
-            fn=wrap_gradio_gpu_call(modules.extras.run_extras),
+            fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
             _js="get_extras_tab_index",
             inputs=[
                 dummy_component,
@@ -1133,6 +1137,7 @@ def create_ui(wrap_gradio_gpu_call):
                 extras_upscaler_1,
                 extras_upscaler_2,
                 extras_upscaler_2_visibility,
+                upscale_before_face_fix,
             ],
             outputs=[
                 result_images,
@@ -1140,19 +1145,11 @@ def create_ui(wrap_gradio_gpu_call):
                 html_info,
             ]
         )
+        parameters_copypaste.add_paste_fields("extras", extras_image, None)
 
-        extras_send_to_img2img.click(
-            fn=lambda x: image_from_url_text(x),
-            _js="extract_image_from_gallery_img2img",
-            inputs=[result_images],
-            outputs=[init_img],
-        )
-
-        extras_send_to_inpaint.click(
-            fn=lambda x: image_from_url_text(x),
-            _js="extract_image_from_gallery_inpaint",
-            inputs=[result_images],
-            outputs=[init_img_with_mask],
+        extras_image.change(
+            fn=modules.extras.clear_cache,
+            inputs=[], outputs=[]
         )
 
     with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
@@ -1162,48 +1159,47 @@ def create_ui(wrap_gradio_gpu_call):
 
             with gr.Column(variant='panel'):
                 html = gr.HTML()
-                generation_info = gr.Textbox(visible=False)
+                generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
                 html2 = gr.HTML()
-
                 with gr.Row():
-                    pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
-                    pnginfo_send_to_img2img = gr.Button('Send to img2img')
+                    buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
+                parameters_copypaste.bind_buttons(buttons, image, generation_info)
 
         image.change(
             fn=wrap_gradio_call(modules.extras.run_pnginfo),
             inputs=[image],
             outputs=[html, generation_info, html2],
         )
-    #images history
-    images_history_switch_dict = {
-        "fn":modules.generation_parameters_copypaste.connect_paste,
-        "t2i":txt2img_paste_fields,
-        "i2i":img2img_paste_fields
-    }
 
-    images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
-
-    with gr.Blocks() as modelmerger_interface:
+    with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
         with gr.Row().style(equal_height=False):
             with gr.Column(variant='panel'):
                 gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
 
                 with gr.Row():
                     primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+                    create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
+
                     secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+                    create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
+
                     tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
-                custom_name = gr.Textbox(label="Custom Name (Optional)")
-                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
-                interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
-                save_as_half = gr.Checkbox(value=False, label="Save as float16")
+                    create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
+
+                custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
+                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
+                interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+
+                with gr.Row():
+                    checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+                    save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
+
                 modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
 
             with gr.Column(variant='panel'):
                 submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
 
-    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
-
-    with gr.Blocks() as train_interface:
+    with gr.Blocks(analytics_enabled=False) as train_interface:
         with gr.Row().style(equal_height=False):
             gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
 
@@ -1211,74 +1207,118 @@ def create_ui(wrap_gradio_gpu_call):
             with gr.Tabs(elem_id="train_tabs"):
 
                 with gr.Tab(label="Create embedding"):
-                    new_embedding_name = gr.Textbox(label="Name")
-                    initialization_text = gr.Textbox(label="Initialization text", value="*")
-                    nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+                    new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
+                    initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
+                    nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
+                    overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
 
                     with gr.Row():
                         with gr.Column(scale=3):
                             gr.HTML(value="")
 
                         with gr.Column():
-                            create_embedding = gr.Button(value="Create embedding", variant='primary')
+                            create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
 
                 with gr.Tab(label="Create hypernetwork"):
-                    new_hypernetwork_name = gr.Textbox(label="Name")
-                    new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
-                    new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
-                    new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+                    new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
+                    new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
+                    new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
+                    new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+                    new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
+                    new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
+                    new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
+                    overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
 
                     with gr.Row():
                         with gr.Column(scale=3):
                             gr.HTML(value="")
 
                         with gr.Column():
-                            create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
+                            create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
 
                 with gr.Tab(label="Preprocess images"):
-                    process_src = gr.Textbox(label='Source directory')
-                    process_dst = gr.Textbox(label='Destination directory')
-                    process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
-                    process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+                    process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
+                    process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
+                    process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
+                    process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
+                    preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
 
                     with gr.Row():
-                        process_flip = gr.Checkbox(label='Create flipped copies')
-                        process_split = gr.Checkbox(label='Split oversized images into two')
-                        process_caption = gr.Checkbox(label='Use BLIP for caption')
-                        process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
+                        process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
+                        process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
+                        process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
+                        process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
+                        process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
+
+                    with gr.Row(visible=False) as process_split_extra_row:
+                        process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
+                        process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
+
+                    with gr.Row(visible=False) as process_focal_crop_row:
+                        process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
+                        process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
+                        process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
+                        process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
 
                     with gr.Row():
                         with gr.Column(scale=3):
                             gr.HTML(value="")
 
                         with gr.Column():
-                            run_preprocess = gr.Button(value="Preprocess", variant='primary')
+                            with gr.Row():
+                                interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
+                            run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
+
+                    process_split.change(
+                        fn=lambda show: gr_show(show),
+                        inputs=[process_split],
+                        outputs=[process_split_extra_row],
+                    )
+
+                    process_focal_crop.change(
+                        fn=lambda show: gr_show(show),
+                        inputs=[process_focal_crop],
+                        outputs=[process_focal_crop_row],
+                    )
 
                 with gr.Tab(label="Train"):
-                    gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
+                    gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
                     with gr.Row():
                         train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
                         create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
                     with gr.Row():
                         train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
                         create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
-                    learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
-                    batch_size = gr.Number(label='Batch size', value=1, precision=0)
-                    dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
-                    log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
-                    template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
-                    training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
-                    training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
-                    steps = gr.Number(label='Max steps', value=100000, precision=0)
-                    create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
-                    save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
-                    save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
-                    preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
+                    with gr.Row():
+                        embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
+                        hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
+
+                    batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
+                    gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
+                    dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
+                    log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
+                    template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
+                    training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
+                    training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+                    steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
+                    create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
+                    save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+                    save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
+                    preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
+                    with gr.Row():
+                        shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
+                        tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
+                    with gr.Row():
+                        latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
 
                     with gr.Row():
-                        interrupt_training = gr.Button(value="Interrupt")
-                        train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
-                        train_embedding = gr.Button(value="Train Embedding", variant='primary')
+                        interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
+                        train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
+                        train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
+
+                params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
+
+                script_callbacks.ui_train_tabs_callback(params)
 
             with gr.Column():
                 progressbar = gr.HTML(elem_id="ti_progressbar")
@@ -1296,6 +1336,7 @@ def create_ui(wrap_gradio_gpu_call):
                 new_embedding_name,
                 initialization_text,
                 nvpt,
+                overwrite_old_embedding,
             ],
             outputs=[
                 train_embedding_name,
@@ -1309,8 +1350,12 @@ def create_ui(wrap_gradio_gpu_call):
             inputs=[
                 new_hypernetwork_name,
                 new_hypernetwork_sizes,
+                overwrite_old_hypernetwork,
                 new_hypernetwork_layer_structure,
+                new_hypernetwork_activation_func,
+                new_hypernetwork_initialization_option,
                 new_hypernetwork_add_layer_norm,
+                new_hypernetwork_use_dropout
             ],
             outputs=[
                 train_hypernetwork_name,
@@ -1327,10 +1372,18 @@ def create_ui(wrap_gradio_gpu_call):
                 process_dst,
                 process_width,
                 process_height,
+                preprocess_txt_action,
                 process_flip,
                 process_split,
                 process_caption,
-                process_caption_deepbooru
+                process_caption_deepbooru,
+                process_split_threshold,
+                process_overlap_ratio,
+                process_focal_crop,
+                process_focal_crop_face_weight,
+                process_focal_crop_entropy_weight,
+                process_focal_crop_edges_weight,
+                process_focal_crop_debug,
             ],
             outputs=[
                 ti_output,
@@ -1343,13 +1396,17 @@ def create_ui(wrap_gradio_gpu_call):
             _js="start_training_textual_inversion",
             inputs=[
                 train_embedding_name,
-                learn_rate,
+                embedding_learn_rate,
                 batch_size,
+                gradient_step,
                 dataset_directory,
                 log_directory,
                 training_width,
                 training_height,
                 steps,
+                shuffle_tags,
+                tag_drop_out,
+                latent_sampling_method,
                 create_image_every,
                 save_embedding_every,
                 template_file,
@@ -1368,13 +1425,17 @@ def create_ui(wrap_gradio_gpu_call):
             _js="start_training_textual_inversion",
             inputs=[
                 train_hypernetwork_name,
-                learn_rate,
+                hypernetwork_learn_rate,
                 batch_size,
+                gradient_step,
                 dataset_directory,
                 log_directory,
                 training_width,
                 training_height,
                 steps,
+                shuffle_tags,
+                tag_drop_out,
+                latent_sampling_method,
                 create_image_every,
                 save_embedding_every,
                 template_file,
@@ -1393,6 +1454,12 @@ def create_ui(wrap_gradio_gpu_call):
             outputs=[],
         )
 
+        interrupt_preprocessing.click(
+            fn=lambda: shared.state.interrupt(),
+            inputs=[],
+            outputs=[],
+        )
+
     def create_setting_component(key, is_quicksettings=False):
         def fun():
             return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1417,142 +1484,108 @@ def create_ui(wrap_gradio_gpu_call):
 
         if info.refresh is not None:
             if is_quicksettings:
-                res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+                res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
                 create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
             else:
-                with gr.Row(variant="compact"):
-                    res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+                with FormRow():
+                    res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
                     create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
         else:
-            res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
-
+            res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
 
         return res
 
     components = []
     component_dict = {}
 
-    def open_folder(f):
-        if not os.path.exists(f):
-            print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
-            return
-        elif not os.path.isdir(f):
-            print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
-            return
-
-        if not shared.cmd_opts.hide_ui_dir_config:
-            path = os.path.normpath(f)
-            if platform.system() == "Windows":
-                os.startfile(path)
-            elif platform.system() == "Darwin":
-                sp.Popen(["open", path])
-            else:
-                sp.Popen(["xdg-open", path])
+    script_callbacks.ui_settings_callback()
+    opts.reorder()
 
     def run_settings(*args):
-        changed = 0
+        changed = []
 
         for key, value, comp in zip(opts.data_labels.keys(), args, components):
-            if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
-                return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
+            assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
 
         for key, value, comp in zip(opts.data_labels.keys(), args, components):
             if comp == dummy_component:
                 continue
 
-            comp_args = opts.data_labels[key].component_args
-            if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
-                continue
+            if opts.set(key, value):
+                changed.append(key)
 
-            if cmd_opts.hide_ui_dir_config and key in restricted_opts:
-                continue
-
-            oldval = opts.data.get(key, None)
-            opts.data[key] = value
-
-            if oldval != value:
-                if opts.data_labels[key].onchange is not None:
-                    opts.data_labels[key].onchange()
-
-                changed += 1
-
-        opts.save(shared.config_filename)
-
-        return f'{changed} settings changed.', opts.dumpjson()
+        try:
+            opts.save(shared.config_filename)
+        except RuntimeError:
+            return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
+        return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
 
     def run_settings_single(value, key):
         if not opts.same_type(value, opts.data_labels[key].default):
             return gr.update(visible=True), opts.dumpjson()
 
-        if cmd_opts.hide_ui_dir_config and key in restricted_opts:
-            return gr.update(value=oldval), opts.dumpjson()
-
-        oldval = opts.data.get(key, None)
-        opts.data[key] = value
-
-        if oldval != value:
-            if opts.data_labels[key].onchange is not None:
-                opts.data_labels[key].onchange()
+        if not opts.set(key, value):
+            return gr.update(value=getattr(opts, key)), opts.dumpjson()
 
         opts.save(shared.config_filename)
 
         return gr.update(value=value), opts.dumpjson()
 
     with gr.Blocks(analytics_enabled=False) as settings_interface:
-        settings_submit = gr.Button(value="Apply settings", variant='primary')
-        result = gr.HTML()
+        with gr.Row():
+            with gr.Column(scale=6):
+                settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+            with gr.Column():
+                restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
 
-        settings_cols = 3
-        items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
+        result = gr.HTML(elem_id="settings_result")
 
         quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
-        quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
+        quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
 
         quicksettings_list = []
 
-        cols_displayed = 0
-        items_displayed = 0
         previous_section = None
-        column = None
-        with gr.Row(elem_id="settings").style(equal_height=False):
+        current_tab = None
+        with gr.Tabs(elem_id="settings"):
             for i, (k, item) in enumerate(opts.data_labels.items()):
+                section_must_be_skipped = item.section[0] is None
 
-                if previous_section != item.section:
-                    if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
-                        if column is not None:
-                            column.__exit__()
+                if previous_section != item.section and not section_must_be_skipped:
+                    elem_id, text = item.section
 
-                        column = gr.Column(variant='panel')
-                        column.__enter__()
+                    if current_tab is not None:
+                        current_tab.__exit__()
 
-                        items_displayed = 0
-                        cols_displayed += 1
+                    current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
+                    current_tab.__enter__()
 
                     previous_section = item.section
 
-                    gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
-
-                if k in quicksettings_names:
+                if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
                     quicksettings_list.append((i, k, item))
                     components.append(dummy_component)
+                elif section_must_be_skipped:
+                    components.append(dummy_component)
                 else:
                     component = create_setting_component(k)
                     component_dict[k] = component
                     components.append(component)
-                    items_displayed += 1
 
-        with gr.Row():
-            request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
-            download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+            if current_tab is not None:
+                current_tab.__exit__()
 
-        with gr.Row():
-            reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
-            restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
+            with gr.TabItem("Actions"):
+                request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+                download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+                reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+
+            if os.path.exists("html/licenses.html"):
+                with open("html/licenses.html", encoding="utf8") as file:
+                    with gr.TabItem("Licenses"):
+                        gr.HTML(file.read(), elem_id="licenses")
+
+            gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
 
         request_notifications.click(
             fn=lambda: None,
@@ -1570,58 +1603,64 @@ Requested path was: {f}
 
         def reload_scripts():
             modules.scripts.reload_script_body_only()
-            reload_javascript() # need to refresh the html page
+            reload_javascript()  # need to refresh the html page
 
         reload_script_bodies.click(
             fn=reload_scripts,
             inputs=[],
-            outputs=[],
-            _js='function(){}'
+            outputs=[]
         )
 
         def request_restart():
             shared.state.interrupt()
-            settings_interface.gradio_ref.do_restart = True
+            shared.state.need_restart = True
 
         restart_gradio.click(
             fn=request_restart,
+            _js='restart_reload',
             inputs=[],
             outputs=[],
-            _js='function(){restart_reload()}'
         )
 
-        if column is not None:
-            column.__exit__()
-
     interfaces = [
         (txt2img_interface, "txt2img", "txt2img"),
         (img2img_interface, "img2img", "img2img"),
         (extras_interface, "Extras", "extras"),
         (pnginfo_interface, "PNG Info", "pnginfo"),
-        (images_history, "History", "images_history"),
         (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
         (train_interface, "Train", "ti"),
-        (settings_interface, "Settings", "settings"),
     ]
 
-    with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
-        css = file.read()
+    css = ""
+
+    for cssfile in modules.scripts.list_files_with_name("style.css"):
+        if not os.path.isfile(cssfile):
+            continue
+
+        with open(cssfile, "r", encoding="utf8") as file:
+            css += file.read() + "\n"
 
     if os.path.exists(os.path.join(script_path, "user.css")):
         with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
-            usercss = file.read()
-            css += usercss
+            css += file.read() + "\n"
 
     if not cmd_opts.no_progressbar_hiding:
         css += css_hide_progressbar
 
+    interfaces += script_callbacks.ui_tabs_callback()
+    interfaces += [(settings_interface, "Settings", "settings")]
+
+    extensions_interface = ui_extensions.create_ui()
+    interfaces += [(extensions_interface, "Extensions", "extensions")]
+
     with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
         with gr.Row(elem_id="quicksettings"):
-            for i, k, item in quicksettings_list:
+            for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
                 component = create_setting_component(k, is_quicksettings=True)
                 component_dict[k] = component
 
-        settings_interface.gradio_ref = demo
+        parameters_copypaste.integrate_settings_paste_fields(component_dict)
+        parameters_copypaste.run_bind()
 
         with gr.Tabs(elem_id="tabs") as tabs:
             for interface, label, ifid in interfaces:
@@ -1631,11 +1670,15 @@ Requested path was: {f}
         if os.path.exists(os.path.join(script_path, "notification.mp3")):
             audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
 
+        if os.path.exists("html/footer.html"):
+            with open("html/footer.html", encoding="utf8") as file:
+                gr.HTML(file.read(), elem_id="footer")
+
         text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
         settings_submit.click(
-            fn=run_settings,
+            fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
             inputs=components,
-            outputs=[result, text_settings],
+            outputs=[text_settings, result],
         )
 
         for i, k, item in quicksettings_list:
@@ -1647,6 +1690,17 @@ Requested path was: {f}
                 outputs=[component, text_settings],
             )
 
+        component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
+
+        def get_settings_values():
+            return [getattr(opts, key) for key in component_keys]
+
+        demo.load(
+            fn=get_settings_values,
+            inputs=[],
+            outputs=[component_dict[k] for k in component_keys],
+        )
+
         def modelmerger(*args):
             try:
                 results = modules.extras.run_modelmerger(*args)
@@ -1654,7 +1708,7 @@ Requested path was: {f}
                 print("Error loading/saving model file:", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
                 modules.sd_models.list_models()  # to remove the potentially missing models from the list
-                return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
+                return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
             return results
 
         modelmerger_merge.click(
@@ -1667,6 +1721,7 @@ Requested path was: {f}
                 interp_amount,
                 save_as_half,
                 custom_name,
+                checkpoint_format,
             ],
             outputs=[
                 submit_result,
@@ -1676,85 +1731,6 @@ Requested path was: {f}
                 component_dict['sd_model_checkpoint'],
             ]
         )
-        paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
-        txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
-        img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
-        send_to_img2img.click(
-            fn=lambda img, *args: (image_from_url_text(img),*args),
-            _js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]",
-            inputs=[txt2img_gallery] + txt2img_fields,
-            outputs=[init_img] + img2img_fields,
-        )
-
-        send_to_inpaint.click(
-            fn=lambda x, *args: (image_from_url_text(x), *args),
-            _js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]",
-            inputs=[txt2img_gallery] + txt2img_fields,
-            outputs=[init_img_with_mask] + img2img_fields,
-        )
-
-        img2img_send_to_img2img.click(
-            fn=lambda x: image_from_url_text(x),
-            _js="extract_image_from_gallery_img2img",
-            inputs=[img2img_gallery],
-            outputs=[init_img],
-        )
-
-        img2img_send_to_inpaint.click(
-            fn=lambda x: image_from_url_text(x),
-            _js="extract_image_from_gallery_inpaint",
-            inputs=[img2img_gallery],
-            outputs=[init_img_with_mask],
-        )
-
-        send_to_extras.click(
-            fn=lambda x: image_from_url_text(x),
-            _js="extract_image_from_gallery_extras",
-            inputs=[txt2img_gallery],
-            outputs=[extras_image],
-        )
-
-        open_txt2img_folder.click(
-            fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
-            inputs=[],
-            outputs=[],
-        )
-
-        open_img2img_folder.click(
-            fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
-            inputs=[],
-            outputs=[],
-        )
-
-        open_extras_folder.click(
-            fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
-            inputs=[],
-            outputs=[],
-        )
-
-        img2img_send_to_extras.click(
-            fn=lambda x: image_from_url_text(x),
-            _js="extract_image_from_gallery_extras",
-            inputs=[img2img_gallery],
-            outputs=[extras_image],
-        )
-
-        settings_map = {
-            'sd_hypernetwork': 'Hypernet',
-            'CLIP_stop_at_last_layers': 'Clip skip',
-            'sd_model_checkpoint': 'Model hash',
-        }
-
-        settings_paste_fields = [
-            (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
-            for k, v in settings_map.items()
-        ]
-
-        modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
-        modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
-
-        modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
-        modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
 
     ui_config_file = cmd_opts.ui_config_file
     ui_settings = {}
@@ -1774,7 +1750,7 @@ Requested path was: {f}
         def apply_field(obj, field, condition=None, init_field=None):
             key = path + "/" + field
 
-            if getattr(obj,'custom_script_source',None) is not None:
+            if getattr(obj, 'custom_script_source', None) is not None:
               key = 'customscript/' + obj.custom_script_source + '/' + key
 
             if getattr(obj, 'do_not_save_to_config', False):
@@ -1829,13 +1805,14 @@ Requested path was: {f}
     return demo
 
 
-def load_javascript(raw_response):
+def reload_javascript():
     with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
         javascript = f'<script>{jsfile.read()}</script>'
 
-    jsdir = os.path.join(script_path, "javascript")
-    for filename in sorted(os.listdir(jsdir)):
-        with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
+    scripts_list = modules.scripts.list_scripts("javascript", ".js")
+
+    for basedir, filename, path in scripts_list:
+        with open(path, "r", encoding="utf8") as jsfile:
             javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
 
     if cmd_opts.theme is not None:
@@ -1844,7 +1821,7 @@ def load_javascript(raw_response):
     javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
 
     def template_response(*args, **kwargs):
-        res = raw_response(*args, **kwargs)
+        res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
         res.body = res.body.replace(
             b'</head>', f'{javascript}</head>'.encode("utf8"))
         res.init_headers()
@@ -1853,6 +1830,5 @@ def load_javascript(raw_response):
     gradio.routes.templates.TemplateResponse = template_response
 
 
-reload_javascript = partial(load_javascript,
-                            gradio.routes.templates.TemplateResponse)
-reload_javascript()
+if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+    shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
diff --git a/modules/ui_components.py b/modules/ui_components.py
new file mode 100644
index 00000000..91eb0e3d
--- /dev/null
+++ b/modules/ui_components.py
@@ -0,0 +1,25 @@
+import gradio as gr
+
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+    """Small button with single emoji as text, fits inside gradio forms"""
+
+    def __init__(self, **kwargs):
+        super().__init__(variant="tool", **kwargs)
+
+    def get_block_name(self):
+        return "button"
+
+
+class FormRow(gr.Row, gr.components.FormComponent):
+    """Same as gr.Row but fits inside gradio forms"""
+
+    def get_block_name(self):
+        return "row"
+
+
+class FormGroup(gr.Group, gr.components.FormComponent):
+    """Same as gr.Row but fits inside gradio forms"""
+
+    def get_block_name(self):
+        return "group"
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
new file mode 100644
index 00000000..eec9586f
--- /dev/null
+++ b/modules/ui_extensions.py
@@ -0,0 +1,328 @@
+import json
+import os.path
+import shutil
+import sys
+import time
+import traceback
+
+import git
+
+import gradio as gr
+import html
+import shutil
+import errno
+
+from modules import extensions, shared, paths
+
+
+available_extensions = {"extensions": []}
+
+
+def check_access():
+    assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
+
+
+def apply_and_restart(disable_list, update_list):
+    check_access()
+
+    disabled = json.loads(disable_list)
+    assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
+
+    update = json.loads(update_list)
+    assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
+
+    update = set(update)
+
+    for ext in extensions.extensions:
+        if ext.name not in update:
+            continue
+
+        try:
+            ext.fetch_and_reset_hard()
+        except Exception:
+            print(f"Error getting updates for {ext.name}:", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+
+    shared.opts.disabled_extensions = disabled
+    shared.opts.save(shared.config_filename)
+
+    shared.state.interrupt()
+    shared.state.need_restart = True
+
+
+def check_updates():
+    check_access()
+
+    for ext in extensions.extensions:
+        if ext.remote is None:
+            continue
+
+        try:
+            ext.check_updates()
+        except Exception:
+            print(f"Error checking updates for {ext.name}:", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+
+    return extension_table()
+
+
+def extension_table():
+    code = f"""<!-- {time.time()} -->
+    <table id="extensions">
+        <thead>
+            <tr>
+                <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
+                <th>URL</th>
+                <th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
+            </tr>
+        </thead>
+        <tbody>
+    """
+
+    for ext in extensions.extensions:
+        remote = ""
+        if ext.is_builtin:
+            remote = "built-in"
+        elif ext.remote:
+            remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
+
+        if ext.can_update:
+            ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
+        else:
+            ext_status = ext.status
+
+        code += f"""
+            <tr>
+                <td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
+                <td>{remote}</td>
+                <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
+            </tr>
+    """
+
+    code += """
+        </tbody>
+    </table>
+    """
+
+    return code
+
+
+def normalize_git_url(url):
+    if url is None:
+        return ""
+
+    url = url.replace(".git", "")
+    return url
+
+
+def install_extension_from_url(dirname, url):
+    check_access()
+
+    assert url, 'No URL specified'
+
+    if dirname is None or dirname == "":
+        *parts, last_part = url.split('/')
+        last_part = normalize_git_url(last_part)
+
+        dirname = last_part
+
+    target_dir = os.path.join(extensions.extensions_dir, dirname)
+    assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
+
+    normalized_url = normalize_git_url(url)
+    assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
+
+    tmpdir = os.path.join(paths.script_path, "tmp", dirname)
+
+    try:
+        shutil.rmtree(tmpdir, True)
+
+        repo = git.Repo.clone_from(url, tmpdir)
+        repo.remote().fetch()
+
+        try:
+            os.rename(tmpdir, target_dir)
+        except OSError as err:
+            # TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
+            # Shouldn't cause any new issues at least but we probably want to handle it there too.
+            if err.errno == errno.EXDEV:
+                # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
+                # Since we can't use a rename, do the slower but more versitile shutil.move()
+                shutil.move(tmpdir, target_dir)
+            else:
+                # Something else, not enough free space, permissions, etc.  rethrow it so that it gets handled.
+                raise(err)
+
+        import launch
+        launch.run_extension_installer(target_dir)
+
+        extensions.list_extensions()
+        return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
+    finally:
+        shutil.rmtree(tmpdir, True)
+
+
+def install_extension_from_index(url, hide_tags):
+    ext_table, message = install_extension_from_url(None, url)
+
+    code, _ = refresh_available_extensions_from_data(hide_tags)
+
+    return code, ext_table, message
+
+
+def refresh_available_extensions(url, hide_tags):
+    global available_extensions
+
+    import urllib.request
+    with urllib.request.urlopen(url) as response:
+        text = response.read()
+
+    available_extensions = json.loads(text)
+
+    code, tags = refresh_available_extensions_from_data(hide_tags)
+
+    return url, code, gr.CheckboxGroup.update(choices=tags), ''
+
+
+def refresh_available_extensions_for_tags(hide_tags):
+    code, _ = refresh_available_extensions_from_data(hide_tags)
+
+    return code, ''
+
+
+def refresh_available_extensions_from_data(hide_tags):
+    extlist = available_extensions["extensions"]
+    installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+
+    tags = available_extensions.get("tags", {})
+    tags_to_hide = set(hide_tags)
+    hidden = 0
+
+    code = f"""<!-- {time.time()} -->
+    <table id="available_extensions">
+        <thead>
+            <tr>
+                <th>Extension</th>
+                <th>Description</th>
+                <th>Action</th>
+            </tr>
+        </thead>
+        <tbody>
+    """
+
+    for ext in extlist:
+        name = ext.get("name", "noname")
+        url = ext.get("url", None)
+        description = ext.get("description", "")
+        extension_tags = ext.get("tags", [])
+
+        if url is None:
+            continue
+
+        existing = installed_extension_urls.get(normalize_git_url(url), None)
+        extension_tags = extension_tags + ["installed"] if existing else extension_tags
+
+        if len([x for x in extension_tags if x in tags_to_hide]) > 0:
+            hidden += 1
+            continue
+
+        install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
+
+        tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
+
+        code += f"""
+            <tr>
+                <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
+                <td>{html.escape(description)}</td>
+                <td>{install_code}</td>
+            </tr>
+        
+        """
+
+        for tag in [x for x in extension_tags if x not in tags]:
+            tags[tag] = tag
+
+    code += """
+        </tbody>
+    </table>
+    """
+
+    if hidden > 0:
+        code += f"<p>Extension hidden: {hidden}</p>"
+
+    return code, list(tags)
+
+
+def create_ui():
+    import modules.ui
+
+    with gr.Blocks(analytics_enabled=False) as ui:
+        with gr.Tabs(elem_id="tabs_extensions") as tabs:
+            with gr.TabItem("Installed"):
+
+                with gr.Row():
+                    apply = gr.Button(value="Apply and restart UI", variant="primary")
+                    check = gr.Button(value="Check for updates")
+                    extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
+                    extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
+
+                extensions_table = gr.HTML(lambda: extension_table())
+
+                apply.click(
+                    fn=apply_and_restart,
+                    _js="extensions_apply",
+                    inputs=[extensions_disabled_list, extensions_update_list],
+                    outputs=[],
+                )
+
+                check.click(
+                    fn=check_updates,
+                    _js="extensions_check",
+                    inputs=[],
+                    outputs=[extensions_table],
+                )
+
+            with gr.TabItem("Available"):
+                with gr.Row():
+                    refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
+                    available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+                    extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
+                    install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
+
+                with gr.Row():
+                    hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
+
+                install_result = gr.HTML()
+                available_extensions_table = gr.HTML()
+
+                refresh_available_extensions_button.click(
+                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+                    inputs=[available_extensions_index, hide_tags],
+                    outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
+                )
+
+                install_extension_button.click(
+                    fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
+                    inputs=[extension_to_install, hide_tags],
+                    outputs=[available_extensions_table, extensions_table, install_result],
+                )
+
+                hide_tags.change(
+                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
+                    inputs=[hide_tags],
+                    outputs=[available_extensions_table, install_result]
+                )
+
+            with gr.TabItem("Install from URL"):
+                install_url = gr.Text(label="URL for extension's git repository")
+                install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
+                install_button = gr.Button(value="Install", variant="primary")
+                install_result = gr.HTML(elem_id="extension_install_result")
+
+                install_button.click(
+                    fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
+                    inputs=[install_dirname, install_url],
+                    outputs=[extensions_table, install_result],
+                )
+
+    return ui
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
new file mode 100644
index 00000000..21945235
--- /dev/null
+++ b/modules/ui_tempdir.py
@@ -0,0 +1,82 @@
+import os
+import tempfile
+from collections import namedtuple
+from pathlib import Path
+
+import gradio as gr
+
+from PIL import PngImagePlugin
+
+from modules import shared
+
+
+Savedfile = namedtuple("Savedfile", ["name"])
+
+
+def register_tmp_file(gradio, filename):
+    if hasattr(gradio, 'temp_file_sets'):  # gradio 3.15
+        gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
+
+    if hasattr(gradio, 'temp_dirs'):  # gradio 3.9
+        gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
+
+
+def check_tmp_file(gradio, filename):
+    if hasattr(gradio, 'temp_file_sets'):
+        return any([filename in fileset for fileset in gradio.temp_file_sets])
+
+    if hasattr(gradio, 'temp_dirs'):
+        return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
+
+    return False
+
+
+def save_pil_to_file(pil_image, dir=None):
+    already_saved_as = getattr(pil_image, 'already_saved_as', None)
+    if already_saved_as and os.path.isfile(already_saved_as):
+        register_tmp_file(shared.demo, already_saved_as)
+
+        file_obj = Savedfile(already_saved_as)
+        return file_obj
+
+    if shared.opts.temp_dir != "":
+        dir = shared.opts.temp_dir
+
+    use_metadata = False
+    metadata = PngImagePlugin.PngInfo()
+    for key, value in pil_image.info.items():
+        if isinstance(key, str) and isinstance(value, str):
+            metadata.add_text(key, value)
+            use_metadata = True
+
+    file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
+    pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
+    return file_obj
+
+
+# override save to file function so that it also writes PNG info
+gr.processing_utils.save_pil_to_file = save_pil_to_file
+
+
+def on_tmpdir_changed():
+    if shared.opts.temp_dir == "" or shared.demo is None:
+        return
+
+    os.makedirs(shared.opts.temp_dir, exist_ok=True)
+
+    register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
+
+
+def cleanup_tmpdr():
+    temp_dir = shared.opts.temp_dir
+    if temp_dir == "" or not os.path.isdir(temp_dir):
+        return
+
+    for root, dirs, files in os.walk(temp_dir, topdown=False):
+        for name in files:
+            _, extension = os.path.splitext(name)
+            if extension != ".png":
+                continue
+
+            filename = os.path.join(root, name)
+            os.remove(filename)
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 6ab2fb40..231680cb 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -10,6 +10,7 @@ import modules.shared
 from modules import modelloader, shared
 
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
 from modules.paths import models_path
 
 
@@ -52,14 +53,22 @@ class Upscaler:
     def do_upscale(self, img: PIL.Image, selected_model: str):
         return img
 
-    def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
+    def upscale(self, img: PIL.Image, scale, selected_model: str = None):
         self.scale = scale
-        dest_w = img.width * scale
-        dest_h = img.height * scale
+        dest_w = int(img.width * scale)
+        dest_h = int(img.height * scale)
+
         for i in range(3):
+            shape = (img.width, img.height)
+
+            img = self.do_upscale(img, selected_model)
+
+            if shape == (img.width, img.height):
+                break
+
             if img.width >= dest_w and img.height >= dest_h:
                 break
-            img = self.do_upscale(img, selected_model)
+
         if img.width != dest_w or img.height != dest_h:
             img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
 
@@ -120,3 +129,17 @@ class UpscalerLanczos(Upscaler):
         self.name = "Lanczos"
         self.scalers = [UpscalerData("Lanczos", None, self)]
 
+
+class UpscalerNearest(Upscaler):
+    scalers = []
+
+    def do_upscale(self, img, selected_model=None):
+        return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
+
+    def load_model(self, _):
+        pass
+
+    def __init__(self, dirname=None):
+        super().__init__(False)
+        self.name = "Nearest"
+        self.scalers = [UpscalerData("Nearest", None, self)]
\ No newline at end of file
diff --git a/modules/xlmr.py b/modules/xlmr.py
new file mode 100644
index 00000000..beab3fdf
--- /dev/null
+++ b/modules/xlmr.py
@@ -0,0 +1,137 @@
+from transformers import BertPreTrainedModel,BertModel,BertConfig
+import torch.nn as nn
+import torch
+from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
+from transformers import XLMRobertaModel,XLMRobertaTokenizer
+from typing import Optional
+
+class BertSeriesConfig(BertConfig):
+    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
+
+        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
+        self.project_dim = project_dim
+        self.pooler_fn = pooler_fn
+        self.learn_encoder = learn_encoder
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+        self.project_dim = project_dim
+        self.pooler_fn = pooler_fn
+        self.learn_encoder = learn_encoder
+
+
+class BertSeriesModelWithTransformation(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    config_class = BertSeriesConfig
+
+    def __init__(self, config=None, **kargs):
+        # modify initialization for autoloading 
+        if config is None:
+            config = XLMRobertaConfig()
+            config.attention_probs_dropout_prob= 0.1
+            config.bos_token_id=0
+            config.eos_token_id=2
+            config.hidden_act='gelu'
+            config.hidden_dropout_prob=0.1
+            config.hidden_size=1024
+            config.initializer_range=0.02
+            config.intermediate_size=4096
+            config.layer_norm_eps=1e-05
+            config.max_position_embeddings=514
+
+            config.num_attention_heads=16
+            config.num_hidden_layers=24
+            config.output_past=True
+            config.pad_token_id=1
+            config.position_embedding_type= "absolute"
+
+            config.type_vocab_size= 1
+            config.use_cache=True
+            config.vocab_size= 250002
+            config.project_dim = 768
+            config.learn_encoder = False
+        super().__init__(config)
+        self.roberta = XLMRobertaModel(config)
+        self.transformation = nn.Linear(config.hidden_size,config.project_dim)
+        self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+        self.pooler = lambda x: x[:,0]
+        self.post_init()
+
+    def encode(self,c):
+        device = next(self.parameters()).device
+        text = self.tokenizer(c,
+                        truncation=True,
+                        max_length=77,
+                        return_length=False,
+                        return_overflowing_tokens=False,
+                        padding="max_length",
+                        return_tensors="pt")
+        text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
+        text["attention_mask"] = torch.tensor(
+            text['attention_mask']).to(device)
+        features = self(**text)
+        return features['projection_state'] 
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) :
+        r"""
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+        outputs = self.roberta(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+            return_dict=return_dict,
+        )
+
+        # last module outputs
+        sequence_output = outputs[0]
+
+
+        # project every module
+        sequence_output_ln = self.pre_LN(sequence_output)
+
+        # pooler
+        pooler_output = self.pooler(sequence_output_ln)
+        pooler_output = self.transformation(pooler_output)
+        projection_state = self.transformation(outputs.last_hidden_state)
+
+        return {
+            'pooler_output':pooler_output,
+            'last_hidden_state':outputs.last_hidden_state,
+            'hidden_states':outputs.hidden_states,
+            'attentions':outputs.attentions,
+            'projection_state':projection_state,
+            'sequence_out': sequence_output
+        }
+
+
+class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
+    base_model_prefix = 'roberta'
+    config_class= RobertaSeriesConfig
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index da1969cf..4f09385f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,16 +1,19 @@
+blendmodes
+accelerate
 basicsr
-diffusers
 fairscale==0.4.4
 fonts
 font-roboto
 gfpgan
-gradio==3.5
+gradio==3.15.0
 invisible-watermark
 numpy
 omegaconf
+opencv-contrib-python
+requests
 piexif
 Pillow
-pytorch_lightning
+pytorch_lightning==1.7.7
 realesrgan
 scikit-image>=0.19
 timm==0.4.12
@@ -24,3 +27,7 @@ torchdiffeq
 kornia
 lark
 inflection
+GitPython
+torchsde
+safetensors
+psutil; sys_platform == 'darwin'
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 72ccc5a3..975102d9 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,8 +1,9 @@
+blendmodes==2022
 transformers==4.19.2
-diffusers==0.3.0
+accelerate==0.12.0
 basicsr==1.4.2
 gfpgan==1.3.8
-gradio==3.5
+gradio==3.15.0
 numpy==1.23.3
 Pillow==9.2.0
 realesrgan==0.3.0
@@ -23,3 +24,7 @@ torchdiffeq==0.2.3
 kornia==0.6.7
 lark==1.1.2
 inflection==0.5.1
+GitPython==3.1.27
+torchsde==0.2.5
+safetensors==0.2.7
+httpcore<=0.15
diff --git a/script.js b/script.js
index 8b3b67e3..0e117d06 100644
--- a/script.js
+++ b/script.js
@@ -1,9 +1,10 @@
-function gradioApp(){
-    return document.getElementsByTagName('gradio-app')[0].shadowRoot;
+function gradioApp() {
+    const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot
+    return !!gradioShadowRoot ? gradioShadowRoot : document;
 }
 
 function get_uiCurrentTab() {
-    return gradioApp().querySelector('.tabs button:not(.border-transparent)')
+    return gradioApp().querySelector('#tabs button:not(.border-transparent)')
 }
 
 function get_uiCurrentTabContent() {
@@ -82,4 +83,4 @@ function uiElementIsVisible(el) {
         }
     }
     return isVisible;
-}
\ No newline at end of file
+}
diff --git a/scripts/custom_code.py b/scripts/custom_code.py
index a9b10c09..22e7b77a 100644
--- a/scripts/custom_code.py
+++ b/scripts/custom_code.py
@@ -14,7 +14,7 @@ class Script(scripts.Script):
         return cmd_opts.allow_code
 
     def ui(self, is_img2img):
-        code = gr.Textbox(label="Python code", visible=False, lines=1)
+        code = gr.Textbox(label="Python code", lines=1)
 
         return [code]
 
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
index d438175c..1229f61b 100644
--- a/scripts/img2imgalt.py
+++ b/scripts/img2imgalt.py
@@ -34,6 +34,9 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
         sigma_in = torch.cat([sigmas[i] * s_in] * 2)
         cond_in = torch.cat([uncond, cond])
 
+        image_conditioning = torch.cat([p.image_conditioning] * 2)
+        cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
+
         c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
         t = dnw.sigma_to_t(sigma_in)
 
@@ -78,6 +81,9 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
         sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
         cond_in = torch.cat([uncond, cond])
 
+        image_conditioning = torch.cat([p.image_conditioning] * 2)
+        cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
+
         c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
 
         if i == 1:
@@ -151,7 +157,7 @@ class Script(scripts.Script):
     def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
         # Override
         if override_sampler:
-            p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler")
+            p.sampler_name = "Euler"
         if override_prompt:
             p.prompt = original_prompt
             p.negative_prompt = original_negative_prompt
@@ -160,8 +166,7 @@ class Script(scripts.Script):
         if override_strength:
             p.denoising_strength = 1.0
 
-
-        def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+        def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
             lat = (p.init_latent.cpu().numpy() * 10).astype(int)
 
             same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
@@ -186,7 +191,7 @@ class Script(scripts.Script):
             
             combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
             
-            sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)
+            sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
 
             sigmas = sampler.model_wrap.get_sigmas(p.steps)
             
@@ -194,7 +199,7 @@ class Script(scripts.Script):
             
             p.seed = p.seed + 1
             
-            return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning)
+            return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
 
         p.sample = sample_extra
 
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index a6468e09..cf71cb92 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -132,7 +132,7 @@ class Script(scripts.Script):
         info = gr.HTML("<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>")
 
         pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128)
-        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, visible=False)
+        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8)
         direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'])
         noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0)
         color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05)
@@ -172,54 +172,54 @@ class Script(scripts.Script):
         if down > 0:
             down = target_h - init_img.height - up
 
-        init_image = p.init_images[0]
-
-        state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)
-
-        def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
+        def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
             is_horiz = is_left or is_right
             is_vert = is_top or is_bottom
             pixels_horiz = expand_pixels if is_horiz else 0
             pixels_vert = expand_pixels if is_vert else 0
 
-            res_w = init.width + pixels_horiz
-            res_h = init.height + pixels_vert
-            process_res_w = math.ceil(res_w / 64) * 64
-            process_res_h = math.ceil(res_h / 64) * 64
+            images_to_process = []
+            output_images = []
+            for n in range(count):
+                res_w = init[n].width + pixels_horiz
+                res_h = init[n].height + pixels_vert
+                process_res_w = math.ceil(res_w / 64) * 64
+                process_res_h = math.ceil(res_h / 64) * 64
 
-            img = Image.new("RGB", (process_res_w, process_res_h))
-            img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
-            mask = Image.new("RGB", (process_res_w, process_res_h), "white")
-            draw = ImageDraw.Draw(mask)
-            draw.rectangle((
-                expand_pixels + mask_blur if is_left else 0,
-                expand_pixels + mask_blur if is_top else 0,
-                mask.width - expand_pixels - mask_blur if is_right else res_w,
-                mask.height - expand_pixels - mask_blur if is_bottom else res_h,
-            ), fill="black")
+                img = Image.new("RGB", (process_res_w, process_res_h))
+                img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
+                mask = Image.new("RGB", (process_res_w, process_res_h), "white")
+                draw = ImageDraw.Draw(mask)
+                draw.rectangle((
+                    expand_pixels + mask_blur if is_left else 0,
+                    expand_pixels + mask_blur if is_top else 0,
+                    mask.width - expand_pixels - mask_blur if is_right else res_w,
+                    mask.height - expand_pixels - mask_blur if is_bottom else res_h,
+                ), fill="black")
 
-            np_image = (np.asarray(img) / 255.0).astype(np.float64)
-            np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
-            noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
-            out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
+                np_image = (np.asarray(img) / 255.0).astype(np.float64)
+                np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
+                noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
+                output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
 
-            target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width
-            target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height
+                target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
+                target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
+                p.width = target_width if is_horiz else img.width
+                p.height = target_height if is_vert else img.height
 
-            crop_region = (
-                0 if is_left else out.width - target_width,
-                0 if is_top else out.height - target_height,
-                target_width if is_left else out.width,
-                target_height if is_top else out.height,
-            )
+                crop_region = (
+                    0 if is_left else output_images[n].width - target_width,
+                    0 if is_top else output_images[n].height - target_height,
+                    target_width if is_left else output_images[n].width,
+                    target_height if is_top else output_images[n].height,
+                )
+                mask = mask.crop(crop_region)
+                p.image_mask = mask
 
-            image_to_process = out.crop(crop_region)
-            mask = mask.crop(crop_region)
+                image_to_process = output_images[n].crop(crop_region)
+                images_to_process.append(image_to_process)
 
-            p.width = target_width if is_horiz else img.width
-            p.height = target_height if is_vert else img.height
-            p.init_images = [image_to_process]
-            p.image_mask = mask
+            p.init_images = images_to_process
 
             latent_mask = Image.new("RGB", (p.width, p.height), "white")
             draw = ImageDraw.Draw(latent_mask)
@@ -232,31 +232,52 @@ class Script(scripts.Script):
             p.latent_mask = latent_mask
 
             proc = process_images(p)
-            proc_img = proc.images[0]
 
             if initial_seed_and_info[0] is None:
                 initial_seed_and_info[0] = proc.seed
                 initial_seed_and_info[1] = proc.info
 
-            out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
-            out = out.crop((0, 0, res_w, res_h))
-            return out
+            for n in range(count):
+                output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
+                output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
 
-        img = init_image
+            return output_images
 
-        if left > 0:
-            img = expand(img, left, is_left=True)
-        if right > 0:
-            img = expand(img, right, is_right=True)
-        if up > 0:
-            img = expand(img, up, is_top=True)
-        if down > 0:
-            img = expand(img, down, is_bottom=True)
+        batch_count = p.n_iter
+        batch_size = p.batch_size
+        p.n_iter = 1
+        state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
+        all_processed_images = []
 
-        res = Processed(p, [img], initial_seed_and_info[0], initial_seed_and_info[1])
+        for i in range(batch_count):
+            imgs = [init_img] * batch_size
+            state.job = f"Batch {i + 1} out of {batch_count}"
+
+            if left > 0:
+                imgs = expand(imgs, batch_size, left, is_left=True)
+            if right > 0:
+                imgs = expand(imgs, batch_size, right, is_right=True)
+            if up > 0:
+                imgs = expand(imgs, batch_size, up, is_top=True)
+            if down > 0:
+                imgs = expand(imgs, batch_size, down, is_bottom=True)
+
+            all_processed_images += imgs
+
+        all_images = all_processed_images
+
+        combined_grid_image = images.image_grid(all_processed_images)
+        unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
+        if opts.return_grid and not unwanted_grid_because_of_img_count:
+            all_images = [combined_grid_image] + all_processed_images
+
+        res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
 
         if opts.samples_save:
-            images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
+            for img in all_processed_images:
+                images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
+
+        if opts.grid_save and not unwanted_grid_because_of_img_count:
+            images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
 
         return res
-
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
index b0469110..ea45beb0 100644
--- a/scripts/poor_mans_outpainting.py
+++ b/scripts/poor_mans_outpainting.py
@@ -22,8 +22,8 @@ class Script(scripts.Script):
             return None
 
         pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128)
-        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False)
-        inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)
+        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
+        inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index")
         direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'])
 
         return [pixels, mask_blur, inpainting_fill, direction]
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index e49c9b20..4c79eaef 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -18,7 +18,7 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
     ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
     hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
 
-    first_pocessed = None
+    first_processed = None
 
     state.job_count = len(xs) * len(ys)
 
@@ -27,17 +27,17 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
             state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
 
             processed = cell(x, y)
-            if first_pocessed is None:
-                first_pocessed = processed
+            if first_processed is None:
+                first_processed = processed
 
             res.append(processed.images[0])
 
     grid = images.image_grid(res, rows=len(ys))
     grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
 
-    first_pocessed.images = [grid]
+    first_processed.images = [grid]
 
-    return first_pocessed
+    return first_processed
 
 
 class Script(scripts.Script):
@@ -46,10 +46,11 @@ class Script(scripts.Script):
 
     def ui(self, is_img2img):
         put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False)
+        different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False)
 
-        return [put_at_start]
+        return [put_at_start, different_seeds]
 
-    def run(self, p, put_at_start):
+    def run(self, p, put_at_start, different_seeds):
         modules.processing.fix_seed(p)
 
         original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
@@ -73,15 +74,17 @@ class Script(scripts.Script):
         print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
 
         p.prompt = all_prompts
-        p.seed = [p.seed for _ in all_prompts]
+        p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
         p.prompt_for_display = original_prompt
         processed = process_images(p)
 
         grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
         grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
         processed.images.insert(0, grid)
+        processed.index_of_first_image = 1
+        processed.infotexts.insert(0, processed.infotexts[0])
 
         if opts.grid_save:
-            images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, grid=True, p=p)
+            images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p)
 
         return processed
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 1266be6f..e8386ed2 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -1,6 +1,7 @@
 import copy
 import math
 import os
+import random
 import sys
 import traceback
 import shlex
@@ -8,6 +9,7 @@ import shlex
 import modules.scripts as scripts
 import gradio as gr
 
+from modules import sd_samplers
 from modules.processing import Processed, process_images
 from PIL import Image
 from modules.shared import opts, cmd_opts, state
@@ -43,6 +45,7 @@ prompt_tags = {
     "seed_resize_from_h": process_int_tag,
     "seed_resize_from_w": process_int_tag,
     "sampler_index": process_int_tag,
+    "sampler_name": process_string_tag,
     "batch_size": process_int_tag,
     "n_iter": process_int_tag,
     "steps": process_int_tag,
@@ -65,14 +68,28 @@ def cmdargs(line):
         arg = args[pos]
 
         assert arg.startswith("--"), f'must start with "--": {arg}'
+        assert pos+1 < len(args), f'missing argument for command line option {arg}'
+
         tag = arg[2:]
 
+        if tag == "prompt" or tag == "negative_prompt":
+            pos += 1
+            prompt = args[pos]
+            pos += 1
+            while pos < len(args) and not args[pos].startswith("--"):
+                prompt += " "
+                prompt += args[pos]
+                pos += 1
+            res[tag] = prompt
+            continue
+
+
         func = prompt_tags.get(tag, None)
         assert func, f'unknown commandline option: {arg}'
 
-        assert pos+1 < len(args), f'missing argument for command line option {arg}'
-
         val = args[pos+1]
+        if tag == "sampler_name":
+            val = sd_samplers.samplers_map.get(val.lower(), None)
 
         res[tag] = func(val)
 
@@ -81,32 +98,36 @@ def cmdargs(line):
     return res
 
 
+def load_prompt_file(file):
+    if file is None:
+        lines = []
+    else:
+        lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
+
+    return None, "\n".join(lines), gr.update(lines=7)
+
+
 class Script(scripts.Script):
     def title(self):
         return "Prompts from file or textbox"
 
     def ui(self, is_img2img):
-        # This checkbox would look nicer as two tabs, but there are two problems:
-        # 1) There is a bug in Gradio 3.3 that prevents visibility from working on Tabs
-        # 2) Even with Gradio 3.3.1, returning a control (like Tabs) that can't be used as input
-        #    causes a AttributeError: 'Tabs' object has no attribute 'preprocess' assert,
-        #    due to the way Script assumes all controls returned can be used as inputs.
-        # Therefore, there's no good way to use grouping components right now,
-        # so we will use a checkbox! :)
-        checkbox_txt = gr.Checkbox(label="Show Textbox", value=False)
-        file = gr.File(label="File with inputs", type='bytes')
-        prompt_txt = gr.TextArea(label="Prompts")
-        checkbox_txt.change(fn=lambda x: [gr.File.update(visible = not x), gr.TextArea.update(visible = x)], inputs=[checkbox_txt], outputs=[file, prompt_txt])
-        return [checkbox_txt, file, prompt_txt]
+        checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False)
+        checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False)
 
-    def on_show(self, checkbox_txt, file, prompt_txt):
-        return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
+        prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1)
+        file = gr.File(label="Upload prompt inputs", type='bytes')
 
-    def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
-        if checkbox_txt:
-            lines = [x.strip() for x in prompt_txt.splitlines()]
-        else:
-            lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")]
+        file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
+
+        # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
+        # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
+        # be unclear to the user that shift-enter is needed.
+        prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
+        return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
+
+    def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
+        lines = [x.strip() for x in prompt_txt.splitlines()]
         lines = [x for x in lines if len(x) > 0]
 
         p.do_not_save_grid = True
@@ -119,7 +140,7 @@ class Script(scripts.Script):
                 try:
                     args = cmdargs(line)
                 except Exception:
-                    print(f"Error parsing line [line] as commandline:", file=sys.stderr)
+                    print(f"Error parsing line {line} as commandline:", file=sys.stderr)
                     print(traceback.format_exc(), file=sys.stderr)
                     args = {"prompt": line}
             else:
@@ -134,9 +155,14 @@ class Script(scripts.Script):
             jobs.append(args)
 
         print(f"Will process {len(lines)} lines in {job_count} jobs.")
+        if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:
+            p.seed = int(random.randrange(4294967294))
+
         state.job_count = job_count
 
         images = []
+        all_prompts = []
+        infotexts = []
         for n, args in enumerate(jobs):
             state.job = f"{state.job_no + 1} out of {state.job_count}"
 
@@ -146,5 +172,10 @@ class Script(scripts.Script):
 
             proc = process_images(copy_p)
             images += proc.images
+            
+            if checkbox_iterate:
+                p.seed = p.seed + (p.batch_size * p.n_iter)
+            all_prompts += proc.all_prompts
+            infotexts += proc.infotexts
 
-        return Processed(p, images, p.seed, "")
+        return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index cb37ff7e..9739545c 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -17,13 +17,14 @@ class Script(scripts.Script):
         return is_img2img
 
     def ui(self, is_img2img):
-        info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image to twice the dimensions; use width and height sliders to set tile size</p>")
-        overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
-        upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False)
+        info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
+        overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)
+        scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0)
+        upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
 
-        return [info, overlap, upscaler_index]
+        return [info, overlap, upscaler_index, scale_factor]
 
-    def run(self, p, _, overlap, upscaler_index):
+    def run(self, p, _, overlap, upscaler_index, scale_factor):
         processing.fix_seed(p)
         upscaler = shared.sd_upscalers[upscaler_index]
 
@@ -34,9 +35,10 @@ class Script(scripts.Script):
         seed = p.seed
 
         init_img = p.init_images[0]
-        
-        if(upscaler.name != "None"): 
-            img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
+        init_img = images.flatten(init_img, opts.img2img_background_color)
+
+        if upscaler.name != "None":
+            img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
         else:
             img = init_img
 
@@ -69,7 +71,7 @@ class Script(scripts.Script):
             work_results = []
             for i in range(batch_count):
                 p.batch_size = batch_size
-                p.init_images = work[i*batch_size:(i+1)*batch_size]
+                p.init_images = work[i * batch_size:(i + 1) * batch_size]
 
                 state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
                 processed = processing.process_images(p)
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 5cca168a..f92f9776 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -10,13 +10,16 @@ import numpy as np
 import modules.scripts as scripts
 import gradio as gr
 
-from modules import images
+from modules import images, paths, sd_samplers
 from modules.hypernetworks import hypernetwork
-from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
+from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
 import modules.sd_samplers
 import modules.sd_models
+import modules.sd_vae
+import glob
+import os
 import re
 
 
@@ -58,29 +61,19 @@ def apply_order(p, x, xs):
         prompt_tmp += part
         prompt_tmp += x[idx]
     p.prompt = prompt_tmp + p.prompt
-    
-
-def build_samplers_dict(p):
-    samplers_dict = {}
-    for i, sampler in enumerate(get_correct_sampler(p)):
-        samplers_dict[sampler.name.lower()] = i
-        for alias in sampler.aliases:
-            samplers_dict[alias.lower()] = i
-    return samplers_dict
 
 
 def apply_sampler(p, x, xs):
-    sampler_index = build_samplers_dict(p).get(x.lower(), None)
-    if sampler_index is None:
+    sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
+    if sampler_name is None:
         raise RuntimeError(f"Unknown sampler: {x}")
 
-    p.sampler_index = sampler_index
+    p.sampler_name = sampler_name
 
 
 def confirm_samplers(p, xs):
-    samplers_dict = build_samplers_dict(p)
     for x in xs:
-        if x.lower() not in samplers_dict.keys():
+        if x.lower() not in sd_samplers.samplers_map:
             raise RuntimeError(f"Unknown sampler: {x}")
 
 
@@ -89,6 +82,7 @@ def apply_checkpoint(p, x, xs):
     if info is None:
         raise RuntimeError(f"Unknown checkpoint: {x}")
     modules.sd_models.reload_model_weights(shared.sd_model, info)
+    p.sd_model = shared.sd_model
 
 
 def confirm_checkpoints(p, xs):
@@ -123,6 +117,38 @@ def apply_clip_skip(p, x, xs):
     opts.data["CLIP_stop_at_last_layers"] = x
 
 
+def apply_upscale_latent_space(p, x, xs):
+    if x.lower().strip() != '0':
+        opts.data["use_scale_latent_for_hires_fix"] = True
+    else:
+        opts.data["use_scale_latent_for_hires_fix"] = False
+
+
+def find_vae(name: str):
+    if name.lower() in ['auto', 'none']:
+        return name
+    else:
+        vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
+        found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True)
+        if found:
+            return found[0]
+        else:
+            return 'auto'
+
+
+def apply_vae(p, x, xs):
+    if x.lower().strip() == 'none':
+        modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None')
+    else:
+        found = find_vae(x)
+        if found:
+            v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found)
+
+
+def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
+    p.styles = x.split(',')
+
+
 def format_value_add_label(p, opt, x):
     if type(x) == float:
         x = round(x, 8)
@@ -152,7 +178,6 @@ def str_permutations(x):
     """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
     return x
 
-
 AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
 AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
 
@@ -177,6 +202,10 @@ axis_options = [
     AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
     AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
     AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
+    AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
+    AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
+    AxisOption("VAE", str, apply_vae, format_value_add_label, None),
+    AxisOption("Styles", str, apply_styles, format_value_add_label, None),
 ]
 
 
@@ -238,9 +267,11 @@ class SharedSettingsStackHelper(object):
         self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
         self.hypernetwork = opts.sd_hypernetwork
         self.model = shared.sd_model
+        self.vae = opts.sd_vae
   
     def __exit__(self, exc_type, exc_value, tb):
         modules.sd_models.reload_model_weights(self.model)
+        modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae))
 
         hypernetwork.load_hypernetwork(self.hypernetwork)
         hypernetwork.apply_strength()
@@ -262,12 +293,12 @@ class Script(scripts.Script):
         current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img]
 
         with gr.Row():
-            x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, visible=False, type="index", elem_id="x_type")
-            x_values = gr.Textbox(label="X values", visible=False, lines=1)
+            x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id="x_type")
+            x_values = gr.Textbox(label="X values", lines=1)
 
         with gr.Row():
-            y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="y_type")
-            y_values = gr.Textbox(label="Y values", visible=False, lines=1)
+            y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id="y_type")
+            y_values = gr.Textbox(label="Y values", lines=1)
         
         draw_legend = gr.Checkbox(label='Draw legend', value=True)
         include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
@@ -392,6 +423,6 @@ class Script(scripts.Script):
             )
 
         if opts.grid_save:
-            images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+            images.save_image(processed.images[0], p.outpath_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
 
         return processed
diff --git a/style.css b/style.css
index 26ae36a5..2116ec3c 100644
--- a/style.css
+++ b/style.css
@@ -73,8 +73,9 @@
     margin-right: auto;
 }
 
-#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
-    min-width: auto;
+[id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{
+    min-width: 2.3em;
+    height: 2.5em;
     flex-grow: 0;
     padding-left: 0.25em;
     padding-right: 0.25em;
@@ -84,27 +85,28 @@
     display: none;
 }
 
-#seed_row, #subseed_row{
+[id$=_seed_row], [id$=_subseed_row]{
     gap: 0.5rem;
+    padding: 0.6em;
 }
 
-#subseed_show_box{
+[id$=_subseed_show_box]{
     min-width: auto;
     flex-grow: 0;
 }
 
-#subseed_show_box > div{
+[id$=_subseed_show_box] > div{
     border: 0;
     height: 100%;
 }
 
-#subseed_show{
+[id$=_subseed_show]{
     min-width: auto;
     flex-grow: 0;
     padding: 0;
 }
 
-#subseed_show label{
+[id$=_subseed_show] label{
     height: 100%;
 }
 
@@ -114,7 +116,7 @@
     padding: 0.4em 0;
 }
 
-#roll, #paste, #style_create, #style_apply{
+#roll_col > button {
     min-width: 2em;
     min-height: 2em;
     max-width: 2em;
@@ -206,24 +208,24 @@ button{
 
 fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500,  label.block span{
     position: absolute;
-    top: -0.6em;
+    top: -0.7em;
     line-height: 1.2em;
-    padding: 0 0.5em;
-    margin: 0;
+    padding: 0;
+    margin: 0 0.5em;
 
     background-color: white;
-    border-top: 1px solid #eee;
-    border-left: 1px solid #eee;
-    border-right: 1px solid #eee;
+    box-shadow:  6px 0 6px 0px white, -6px 0 6px 0px white;
 
     z-index: 300;
 }
 
 .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
     background-color: rgb(31, 41, 55);
-    border-top: 1px solid rgb(55 65 81);
-    border-left: 1px solid rgb(55 65 81);
-    border-right: 1px solid rgb(55 65 81);
+    box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55);
+}
+
+#txt2img_column_batch, #img2img_column_batch{
+    min-width: min(13.5em, 100%) !important;
 }
 
 #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
@@ -232,22 +234,40 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500,  label.block s
     margin-right: 8em;
 }
 
-.gr-panel div.flex-col div.justify-between label span{
-    margin: 0;
-}
-
 #settings .gr-panel div.flex-col div.justify-between div{
     position: relative;
     z-index: 200;
 }
 
-input[type="range"]{
-    margin: 0.5em 0 -0.3em 0;
+#settings{
+    display: block;
 }
 
-#txt2img_sampling label{
-    padding-left: 0.6em;
-    padding-right: 0.6em;
+#settings > div{
+    border: none;
+    margin-left: 10em;
+}
+
+#settings > div.flex-wrap{
+    float: left;
+    display: block;
+    margin-left: 0;
+    width: 10em;
+}
+
+#settings > div.flex-wrap button{
+    display: block;
+    border: none;
+    text-align: left;
+}
+
+#settings_result{
+    height: 1.4em;
+    margin: 0 1.2em;
+}
+
+input[type="range"]{
+    margin: 0.5em 0 -0.3em 0;
 }
 
 #mask_bug_info {
@@ -260,6 +280,16 @@ input[type="range"]{
 #txt2img_negative_prompt, #img2img_negative_prompt{
 }
 
+/* gradio 3.8 adds opacity to progressbar which makes it blink; disable it here */
+.transition.opacity-20 {
+  opacity: 1 !important;
+}
+
+/* more gradio's garbage cleanup */
+.min-h-\[4rem\] {
+  min-height: unset !important;
+}
+
 #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
   position: absolute;
   z-index: 1000;
@@ -314,8 +344,8 @@ input[type="range"]{
 
 .modalControls {
     display: grid;
-    grid-template-columns: 32px auto 1fr 32px;
-    grid-template-areas: "zoom tile space close";
+    grid-template-columns: 32px 32px 32px 1fr 32px;
+    grid-template-areas: "zoom tile save space close";
     position: absolute;
     top: 0;
     left: 0;
@@ -333,6 +363,10 @@ input[type="range"]{
     grid-area: zoom;
 }
 
+.modalSave {
+    grid-area: save;
+}
+
 .modalTileImage {
     grid-area: tile;
 }
@@ -346,8 +380,18 @@ input[type="range"]{
   cursor: pointer;
 }
 
+.modalSave {
+    color: white;
+    font-size: 28px;
+    margin-top: 8px;
+    font-weight: bold;
+    cursor: pointer;
+}
+
 .modalClose:hover,
 .modalClose:focus,
+.modalSave:hover,
+.modalSave:focus,
 .modalZoom:hover,
 .modalZoom:focus {
   color: #999;
@@ -477,13 +521,6 @@ input[type="range"]{
     padding: 0;
 }
 
-#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
-    max-width: 2.5em;
-    min-width: 2.5em;
-    height: 2.4em;
-}
-
-
 canvas[key="mask"] {
     z-index: 12 !important;
     filter: invert();
@@ -497,7 +534,7 @@ canvas[key="mask"] {
     position: absolute;
     right: 0.5em;
     top: -0.6em;
-    z-index: 200;
+    z-index: 400;
     width: 8em;
 }
 #quicksettings .gr-box > div > div > input.gr-text-input {
@@ -515,3 +552,157 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
     max-height: 480px !important;
     min-height: 480px !important;
 }
+
+/* Extensions */
+
+#tab_extensions table{
+    border-collapse: collapse;
+}
+
+#tab_extensions table td, #tab_extensions table th{
+    border: 1px solid #ccc;
+    padding: 0.25em 0.5em;
+}
+
+#tab_extensions table input[type="checkbox"]{
+    margin-right: 0.5em;
+}
+
+#tab_extensions button{
+    max-width: 16em;
+}
+
+#tab_extensions input[disabled="disabled"]{
+    opacity: 0.5;
+}
+
+.extension-tag{
+    font-weight: bold;
+    font-size: 95%;
+}
+
+#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{
+    min-width: auto;
+    padding-left: 0.5em;
+    padding-right: 0.5em;
+}
+
+.gr-form{
+    background-color: white;
+}
+
+.dark .gr-form{
+    background-color: rgb(31 41 55 / var(--tw-bg-opacity));
+}
+
+.gr-button-tool{
+    max-width: 2.5em;
+    min-width: 2.5em !important;
+    height: 2.4em;
+    margin: 0.55em 0;
+}
+
+#quicksettings .gr-button-tool{
+    margin: 0;
+}
+
+
+#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
+    padding-top: 0.9em;
+}
+
+#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form{
+    border: none;
+    padding-bottom: 0.5em;
+}
+
+footer {
+    display: none !important;
+}
+
+#footer{
+    text-align: center;
+}
+
+#footer div{
+    display: inline-block;
+}
+
+/* The following handles localization for right-to-left (RTL) languages like Arabic.
+The rtl media type will only be activated by the logic in javascript/localization.js.
+If you change anything above, you need to make sure it is RTL compliant by just running
+your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/.
+Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/
+@media rtl {
+    /* this part was added manually */
+    :host {
+        direction: rtl;
+    }
+    select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress {
+        direction: ltr;
+    }
+    #script_list > label > select,
+    #x_type > label > select,
+    #y_type > label > select {
+        direction: rtl;
+    }
+    .gr-radio, .gr-checkbox{
+        margin-left: 0.25em;
+    }
+
+    /* automatically generated with few manual modifications */
+    .performance .time {
+        margin-right: unset;
+        margin-left: 0;
+    }
+    .justify-center.overflow-x-scroll {
+        justify-content: right;
+    }
+    .justify-center.overflow-x-scroll button:first-of-type {
+        margin-left: unset;
+        margin-right: auto;
+    }
+    .justify-center.overflow-x-scroll button:last-of-type {
+        margin-right: unset;
+        margin-left: auto;
+    }
+    #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
+        margin-right: unset;
+        margin-left: 8em;
+    }
+    #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
+        right: unset;
+        left: 0;
+    }
+    .progressDiv .progress{
+        padding: 0 0 0 8px;
+        text-align: left;
+    }
+    #lightboxModal{
+        left: unset;
+        right: 0;
+    }
+    .modalPrev, .modalNext{
+        border-radius: 3px 0 0 3px;
+    }
+    .modalNext {
+        right: unset;
+        left: 0;
+        border-radius: 0 3px 3px 0;
+    }
+    #imageARPreview{
+        left:unset;
+        right:0px;
+    }
+    #txt2img_skip, #img2img_skip{
+        right: unset;
+        left: 0px;
+    }
+    #context-menu{
+        box-shadow:-1px 1px 2px #CE6400;
+    }
+    .gr-box > div > div > input.gr-text-input{
+        right: unset;
+        left: 0.5em;
+    }
+}
\ No newline at end of file
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/advanced_features/__init__.py b/test/advanced_features/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/advanced_features/extras_test.py b/test/advanced_features/extras_test.py
new file mode 100644
index 00000000..8763f8ed
--- /dev/null
+++ b/test/advanced_features/extras_test.py
@@ -0,0 +1,29 @@
+import unittest
+
+
+class TestExtrasWorking(unittest.TestCase):
+    def setUp(self):
+        self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image"
+        self.simple_extras = {
+            "resize_mode": 0,
+            "show_extras_results": True,
+            "gfpgan_visibility": 0,
+            "codeformer_visibility": 0,
+            "codeformer_weight": 0,
+            "upscaling_resize": 2,
+            "upscaling_resize_w": 128,
+            "upscaling_resize_h": 128,
+            "upscaling_crop": True,
+            "upscaler_1": "None",
+            "upscaler_2": "None",
+            "extras_upscaler_2_visibility": 0,
+            "image": ""
+            }
+
+
+class TestExtrasCorrectness(unittest.TestCase):
+    pass
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/advanced_features/txt2img_test.py b/test/advanced_features/txt2img_test.py
new file mode 100644
index 00000000..36ed7b9a
--- /dev/null
+++ b/test/advanced_features/txt2img_test.py
@@ -0,0 +1,47 @@
+import unittest
+import requests
+
+
+class TestTxt2ImgWorking(unittest.TestCase):
+    def setUp(self):
+        self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
+        self.simple_txt2img = {
+            "enable_hr": False,
+            "denoising_strength": 0,
+            "firstphase_width": 0,
+            "firstphase_height": 0,
+            "prompt": "example prompt",
+            "styles": [],
+            "seed": -1,
+            "subseed": -1,
+            "subseed_strength": 0,
+            "seed_resize_from_h": -1,
+            "seed_resize_from_w": -1,
+            "batch_size": 1,
+            "n_iter": 1,
+            "steps": 3,
+            "cfg_scale": 7,
+            "width": 64,
+            "height": 64,
+            "restore_faces": False,
+            "tiling": False,
+            "negative_prompt": "",
+            "eta": 0,
+            "s_churn": 0,
+            "s_tmax": 0,
+            "s_tmin": 0,
+            "s_noise": 1,
+            "sampler_index": "Euler a"
+        }
+
+    def test_txt2img_with_restore_faces_performed(self):
+        self.simple_txt2img["restore_faces"] = True
+        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+
+
+class TestTxt2ImgCorrectness(unittest.TestCase):
+    pass
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/basic_features/__init__.py b/test/basic_features/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py
new file mode 100644
index 00000000..0a9c1e8a
--- /dev/null
+++ b/test/basic_features/img2img_test.py
@@ -0,0 +1,55 @@
+import unittest
+import requests
+from gradio.processing_utils import encode_pil_to_base64
+from PIL import Image
+
+
+class TestImg2ImgWorking(unittest.TestCase):
+    def setUp(self):
+        self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
+        self.simple_img2img = {
+            "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
+            "resize_mode": 0,
+            "denoising_strength": 0.75,
+            "mask": None,
+            "mask_blur": 4,
+            "inpainting_fill": 0,
+            "inpaint_full_res": False,
+            "inpaint_full_res_padding": 0,
+            "inpainting_mask_invert": 0,
+            "prompt": "example prompt",
+            "styles": [],
+            "seed": -1,
+            "subseed": -1,
+            "subseed_strength": 0,
+            "seed_resize_from_h": -1,
+            "seed_resize_from_w": -1,
+            "batch_size": 1,
+            "n_iter": 1,
+            "steps": 3,
+            "cfg_scale": 7,
+            "width": 64,
+            "height": 64,
+            "restore_faces": False,
+            "tiling": False,
+            "negative_prompt": "",
+            "eta": 0,
+            "s_churn": 0,
+            "s_tmax": 0,
+            "s_tmin": 0,
+            "s_noise": 1,
+            "override_settings": {},
+            "sampler_index": "Euler a",
+            "include_init_images": False
+            }
+
+    def test_img2img_simple_performed(self):
+        self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
+
+    def test_inpainting_masked_performed(self):
+        self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
+        self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py
new file mode 100644
index 00000000..1c2674b2
--- /dev/null
+++ b/test/basic_features/txt2img_test.py
@@ -0,0 +1,68 @@
+import unittest
+import requests
+
+
+class TestTxt2ImgWorking(unittest.TestCase):
+    def setUp(self):
+        self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
+        self.simple_txt2img = {
+            "enable_hr": False,
+            "denoising_strength": 0,
+            "firstphase_width": 0,
+            "firstphase_height": 0,
+            "prompt": "example prompt",
+            "styles": [],
+            "seed": -1,
+            "subseed": -1,
+            "subseed_strength": 0,
+            "seed_resize_from_h": -1,
+            "seed_resize_from_w": -1,
+            "batch_size": 1,
+            "n_iter": 1,
+            "steps": 3,
+            "cfg_scale": 7,
+            "width": 64,
+            "height": 64,
+            "restore_faces": False,
+            "tiling": False,
+            "negative_prompt": "",
+            "eta": 0,
+            "s_churn": 0,
+            "s_tmax": 0,
+            "s_tmin": 0,
+            "s_noise": 1,
+            "sampler_index": "Euler a"
+        }
+
+    def test_txt2img_simple_performed(self):
+        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+
+    def test_txt2img_with_negative_prompt_performed(self):
+        self.simple_txt2img["negative_prompt"] = "example negative prompt"
+        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+
+    def test_txt2img_not_square_image_performed(self):
+        self.simple_txt2img["height"] = 128
+        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+
+    def test_txt2img_with_hrfix_performed(self):
+        self.simple_txt2img["enable_hr"] = True
+        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+
+    def test_txt2img_with_tiling_performed(self):
+        self.simple_txt2img["tiling"] = True
+        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+
+    def test_txt2img_with_vanilla_sampler_performed(self):
+        self.simple_txt2img["sampler_index"] = "PLMS"
+        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+        self.simple_txt2img["sampler_index"] = "DDIM"
+        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+
+    def test_txt2img_multiple_batches_performed(self):
+        self.simple_txt2img["n_iter"] = 2
+        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py
new file mode 100644
index 00000000..765470c9
--- /dev/null
+++ b/test/basic_features/utils_test.py
@@ -0,0 +1,53 @@
+import unittest
+import requests
+
+class UtilsTests(unittest.TestCase):
+  def setUp(self):
+    self.url_options = "http://localhost:7860/sdapi/v1/options"
+    self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
+    self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
+    self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
+    self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
+    self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
+    self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
+    self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
+    self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
+    self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories"
+    self.url_artists = "http://localhost:7860/sdapi/v1/artists"
+
+  def test_options_get(self):
+    self.assertEqual(requests.get(self.url_options).status_code, 200)
+
+  def test_cmd_flags(self):
+    self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
+
+  def test_samplers(self):
+    self.assertEqual(requests.get(self.url_samplers).status_code, 200)
+
+  def test_upscalers(self):
+    self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
+
+  def test_sd_models(self):
+    self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
+
+  def test_hypernetworks(self):
+    self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
+
+  def test_face_restorers(self):
+    self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
+  
+  def test_realesrgan_models(self):
+    self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
+  
+  def test_prompt_styles(self):
+    self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
+  
+  def test_artist_categories(self):
+    self.assertEqual(requests.get(self.url_artist_categories).status_code, 200)
+
+  def test_artists(self):
+    self.assertEqual(requests.get(self.url_artists).status_code, 200)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/server_poll.py b/test/server_poll.py
new file mode 100644
index 00000000..d4df697b
--- /dev/null
+++ b/test/server_poll.py
@@ -0,0 +1,24 @@
+import unittest
+import requests
+import time
+
+
+def run_tests(proc, test_dir):
+    timeout_threshold = 240
+    start_time = time.time()
+    while time.time()-start_time < timeout_threshold:
+        try:
+            requests.head("http://localhost:7860/")
+            break
+        except requests.exceptions.ConnectionError:
+            if proc.poll() is not None:
+                break
+    if proc.poll() is None:
+        if test_dir is None:
+            test_dir = ""
+        suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
+        result = unittest.TextTestRunner(verbosity=2).run(suite)
+        return len(result.failures) + len(result.errors)
+    else:
+        print("Launch unsuccessful")
+        return 1
diff --git a/test/test_files/empty.pt b/test/test_files/empty.pt
new file mode 100644
index 00000000..c6ac59eb
Binary files /dev/null and b/test/test_files/empty.pt differ
diff --git a/test/test_files/img2img_basic.png b/test/test_files/img2img_basic.png
new file mode 100644
index 00000000..49a42048
Binary files /dev/null and b/test/test_files/img2img_basic.png differ
diff --git a/test/test_files/mask_basic.png b/test/test_files/mask_basic.png
new file mode 100644
index 00000000..0c2e9a68
Binary files /dev/null and b/test/test_files/mask_basic.png differ
diff --git a/v2-inference-v.yaml b/v2-inference-v.yaml
new file mode 100644
index 00000000..513cd635
--- /dev/null
+++ b/v2-inference-v.yaml
@@ -0,0 +1,68 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    parameterization: "v"
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False # we set this to false because this is an inference only config
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        use_checkpoint: True
+        use_fp16: True
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
\ No newline at end of file
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
new file mode 100644
index 00000000..95ca9c55
--- /dev/null
+++ b/webui-macos-env.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+####################################################################
+#                          macOS defaults                          #
+# Please modify webui-user.sh to change these instead of this file #
+####################################################################
+
+if [[ -x "$(command -v python3.10)" ]]
+then
+    python_cmd="python3.10"
+fi
+
+export install_dir="$HOME"
+export COMMANDLINE_ARGS="--skip-torch-cuda-test --no-half --use-cpu interrogate"
+export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
+export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
+export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
+export PYTORCH_ENABLE_MPS_FALLBACK=1
+
+####################################################################
diff --git a/webui-user.sh b/webui-user.sh
index 30646f5c..bfa53cb7 100644
--- a/webui-user.sh
+++ b/webui-user.sh
@@ -10,7 +10,7 @@
 #clone_dir="stable-diffusion-webui"
 
 # Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
-export COMMANDLINE_ARGS=""
+#export COMMANDLINE_ARGS=""
 
 # python3 executable
 #python_cmd="python3"
@@ -40,4 +40,7 @@ export COMMANDLINE_ARGS=""
 #export CODEFORMER_COMMIT_HASH=""
 #export BLIP_COMMIT_HASH=""
 
+# Uncomment to enable accelerated launch
+#export ACCELERATE="True"
+
 ###########################################
diff --git a/webui.bat b/webui.bat
index a38a28bb..d4d626e2 100644
--- a/webui.bat
+++ b/webui.bat
@@ -28,15 +28,27 @@ goto :show_stdout_stderr
 :activate_venv
 set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
 echo venv %PYTHON%
+if [%ACCELERATE%] == ["True"] goto :accelerate
 goto :launch
 
 :skip_venv
 
+:accelerate
+echo "Checking for accelerate"
+set ACCELERATE="%~dp0%VENV_DIR%\Scripts\accelerate.exe"
+if EXIST %ACCELERATE% goto :accelerate_launch
+
 :launch
 %PYTHON% launch.py %*
 pause
 exit /b
 
+:accelerate_launch
+echo "Accelerating"
+%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
+pause
+exit /b
+
 :show_stdout_stderr
 
 echo.
diff --git a/webui.py b/webui.py
index 177bef74..13375e71 100644
--- a/webui.py
+++ b/webui.py
@@ -1,15 +1,19 @@
 import os
+import sys
 import threading
 import time
 import importlib
 import signal
 import threading
 from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 
+from modules import import_hook, errors
+from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
 from modules.paths import script_path
 
-from modules import devices, sd_samplers
+from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
 import modules.codeformer_model as codeformer
 import modules.extras
 import modules.face_restoration
@@ -21,70 +25,70 @@ import modules.paths
 import modules.scripts
 import modules.sd_hijack
 import modules.sd_models
-import modules.shared as shared
+import modules.sd_vae
 import modules.txt2img
+import modules.script_callbacks
 
 import modules.ui
-from modules import devices
 from modules import modelloader
-from modules.paths import script_path
 from modules.shared import cmd_opts
 import modules.hypernetworks.hypernetwork
 
-queue_lock = threading.Lock()
 
+if cmd_opts.server_name:
+    server_name = cmd_opts.server_name
+else:
+    server_name = "0.0.0.0" if cmd_opts.listen else None
 
-def wrap_queued_call(func):
-    def f(*args, **kwargs):
-        with queue_lock:
-            res = func(*args, **kwargs)
-
-        return res
-
-    return f
-
-
-def wrap_gradio_gpu_call(func, extra_outputs=None):
-    def f(*args, **kwargs):
-        devices.torch_gc()
-
-        shared.state.sampling_step = 0
-        shared.state.job_count = -1
-        shared.state.job_no = 0
-        shared.state.job_timestamp = shared.state.get_job_timestamp()
-        shared.state.current_latent = None
-        shared.state.current_image = None
-        shared.state.current_image_sampling_step = 0
-        shared.state.skipped = False
-        shared.state.interrupted = False
-        shared.state.textinfo = None
-
-        with queue_lock:
-            res = func(*args, **kwargs)
-
-        shared.state.job = ""
-        shared.state.job_count = 0
-
-        devices.torch_gc()
-
-        return res
-
-    return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
 
 def initialize():
+    extensions.list_extensions()
+    localization.list_localizations(cmd_opts.localizations_dir)
+
+    if cmd_opts.ui_debug_mode:
+        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
+        modules.scripts.load_scripts()
+        return
+
     modelloader.cleanup_models()
     modules.sd_models.setup_model()
     codeformer.setup_model(cmd_opts.codeformer_models_path)
     gfpgan.setup_model(cmd_opts.gfpgan_models_path)
     shared.face_restorers.append(modules.face_restoration.FaceRestoration())
+
+    modelloader.list_builtin_upscalers()
+    modules.scripts.load_scripts()
     modelloader.load_upscalers()
 
-    modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
+    modules.sd_vae.refresh_vae_list()
 
-    shared.sd_model = modules.sd_models.load_model()
-    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
-    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+    try:
+        modules.sd_models.load_model()
+    except Exception as e:
+        errors.display(e, "loading stable diffusion model")
+        print("", file=sys.stderr)
+        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
+        exit(1)
+
+    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
+    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
     shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
+    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+
+    if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
+
+        try:
+            if not os.path.exists(cmd_opts.tls_keyfile):
+                print("Invalid path to TLS keyfile given")
+            if not os.path.exists(cmd_opts.tls_certfile):
+                print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
+        except TypeError:
+            cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
+            print("TLS setup invalid, running webui without TLS")
+        else:
+            print("Running with TLS")
 
     # make the program just exit at ctrl+c without waiting for anything
     def sigint_handler(sig, frame):
@@ -94,68 +98,106 @@ def initialize():
     signal.signal(signal.SIGINT, sigint_handler)
 
 
+def setup_cors(app):
+    if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
+        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
+    elif cmd_opts.cors_allow_origins:
+        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
+    elif cmd_opts.cors_allow_origins_regex:
+        app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
+
+
 def create_api(app):
     from modules.api.api import Api
     api = Api(app, queue_lock)
     return api
 
+
 def wait_on_server(demo=None):
     while 1:
         time.sleep(0.5)
-        if demo and getattr(demo, 'do_restart', False):
+        if shared.state.need_restart:
+            shared.state.need_restart = False
             time.sleep(0.5)
             demo.close()
             time.sleep(0.5)
             break
 
+
 def api_only():
     initialize()
 
     app = FastAPI()
+    setup_cors(app)
     app.add_middleware(GZipMiddleware, minimum_size=1000)
     api = create_api(app)
 
+    modules.script_callbacks.app_started_callback(None, app)
+
     api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
 
 
-def webui(launch_api=False):
+def webui():
+    launch_api = cmd_opts.api
     initialize()
 
     while 1:
-        demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
+        if shared.opts.clean_temp_dir_at_start:
+            ui_tempdir.cleanup_tmpdr()
 
-        app, local_url, share_url = demo.launch(
+        shared.demo = modules.ui.create_ui()
+
+        app, local_url, share_url = shared.demo.queue(default_enabled=False).launch(
             share=cmd_opts.share,
-            server_name="0.0.0.0" if cmd_opts.listen else None,
+            server_name=server_name,
             server_port=cmd_opts.port,
+            ssl_keyfile=cmd_opts.tls_keyfile,
+            ssl_certfile=cmd_opts.tls_certfile,
             debug=cmd_opts.gradio_debug,
             auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
             inbrowser=cmd_opts.autolaunch,
             prevent_thread_lock=True
         )
+        # after initial launch, disable --autolaunch for subsequent restarts
+        cmd_opts.autolaunch = False
+
+        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
+        # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
+        # running web ui and do whatever the attacker wants, including installing an extension and
+        # running its code. We disable this here. Suggested by RyotaK.
+        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
+
+        setup_cors(app)
 
         app.add_middleware(GZipMiddleware, minimum_size=1000)
 
-        if (launch_api):
+        if launch_api:
             create_api(app)
 
-        wait_on_server(demo)
+        modules.script_callbacks.app_started_callback(shared.demo, app)
+        modules.script_callbacks.app_started_callback(shared.demo, app)
+
+        wait_on_server(shared.demo)
+        print('Restarting UI...')
 
         sd_samplers.set_samplers()
 
-        print('Reloading Custom Scripts')
-        modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
-        print('Reloading modules: modules.ui')
-        importlib.reload(modules.ui)
-        print('Refreshing Model List')
+        extensions.list_extensions()
+
+        localization.list_localizations(cmd_opts.localizations_dir)
+
+        modelloader.forbid_loaded_nonbuiltin_upscalers()
+        modules.scripts.reload_scripts()
+        modelloader.load_upscalers()
+
+        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
+            importlib.reload(module)
+
         modules.sd_models.list_models()
-        print('Restarting Gradio')
 
 
-
-task = []
 if __name__ == "__main__":
     if cmd_opts.nowebui:
         api_only()
     else:
-        webui(cmd_opts.api)
+        webui()
diff --git a/webui.sh b/webui.sh
index a9f85d89..04ecbf76 100755
--- a/webui.sh
+++ b/webui.sh
@@ -1,8 +1,17 @@
-#!/bin/bash
+#!/usr/bin/env bash
 #################################################
 # Please do not make any changes to this file,  #
 # change the variables in webui-user.sh instead #
 #################################################
+
+# If run from macOS, load defaults from webui-macos-env.sh
+if [[ "$OSTYPE" == "darwin"* ]]; then
+    if [[ -f webui-macos-env.sh ]]
+        then
+        source ./webui-macos-env.sh
+    fi
+fi
+
 # Read variables from webui-user.sh
 # shellcheck source=/dev/null
 if [[ -f webui-user.sh ]]
@@ -46,6 +55,18 @@ then
     LAUNCH_SCRIPT="launch.py"
 fi
 
+# this script cannot be run as root by default
+can_run_as_root=0
+
+# read any command line flags to the webui.sh script
+while getopts "f" flag > /dev/null 2>&1
+do
+    case ${flag} in
+        f) can_run_as_root=1;;
+        *) break;;
+    esac
+done
+
 # Disable sentry logging
 export ERROR_REPORTING=FALSE
 
@@ -61,7 +82,7 @@ printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
 printf "\n%s\n" "${delimiter}"
 
 # Do not run as root
-if [[ $(id -u) -eq 0 ]]
+if [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]]
 then
     printf "\n%s\n" "${delimiter}"
     printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m"
@@ -102,15 +123,14 @@ then
     exit 1
 fi
 
-printf "\n%s\n" "${delimiter}"
-printf "Clone or update stable-diffusion-webui"
-printf "\n%s\n" "${delimiter}"
 cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; }
 if [[ -d "${clone_dir}" ]]
 then
     cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
-    "${GIT}" pull
 else
+    printf "\n%s\n" "${delimiter}"
+    printf "Clone stable-diffusion-webui"
+    printf "\n%s\n" "${delimiter}"
     "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${clone_dir}"
     cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
 fi
@@ -135,7 +155,15 @@ else
     exit 1
 fi
 
-printf "\n%s\n" "${delimiter}"
-printf "Launching launch.py..."
-printf "\n%s\n" "${delimiter}"
-"${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
+then
+    printf "\n%s\n" "${delimiter}"
+    printf "Accelerating launch.py..."
+    printf "\n%s\n" "${delimiter}"
+    accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
+else
+    printf "\n%s\n" "${delimiter}"
+    printf "Launching launch.py..."
+    printf "\n%s\n" "${delimiter}"
+    "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+fi