Merge branch 'master' into features-to-readme

commit ec37f8a45f
.github/ISSUE_TEMPLATE/bug_report.md (vendored, 32 lines removed)

@@ -1,32 +0,0 @@
----
-name: Bug report
-about: Create a report to help us improve
-title: ''
-labels: bug-report
-assignees: ''
-
----
-
-**Describe the bug**
-A clear and concise description of what the bug is.
-
-**To Reproduce**
-Steps to reproduce the behavior:
-1. Go to '...'
-2. Click on '....'
-3. Scroll down to '....'
-4. See error
-
-**Expected behavior**
-A clear and concise description of what you expected to happen.
-
-**Screenshots**
-If applicable, add screenshots to help explain your problem.
-
-**Desktop (please complete the following information):**
- - OS: [e.g. Windows, Linux]
- - Browser [e.g. chrome, safari]
- - Commit revision [looks like this: e68484500f76a33ba477d5a99340ab30451e557b; can be seen when launching webui.bat, or obtained manually by running `git rev-parse HEAD`]
-
-**Additional context**
-Add any other context about the problem here.
.github/ISSUE_TEMPLATE/bug_report.yml (vendored, new file, 83 lines)

@@ -0,0 +1,83 @@
+name: Bug Report
+description: You think something is broken in the UI
+title: "[Bug]: "
+labels: ["bug-report"]
+
+body:
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
+      options:
+        - label: I have searched the existing issues and checked the recent builds/commits
+          required: true
+  - type: markdown
+    attributes:
+      value: |
+        *Please fill this form with as much information as possible; don't forget to fill "What OS..." and "What browsers", and provide screenshots if possible*
+  - type: textarea
+    id: what-did
+    attributes:
+      label: What happened?
+      description: Tell us what happened in a very clear and simple way
+    validations:
+      required: true
+  - type: textarea
+    id: steps
+    attributes:
+      label: Steps to reproduce the problem
+      description: Please provide us with precise step by step information on how to reproduce the bug
+      value: |
+        1. Go to ....
+        2. Press ....
+        3. ...
+    validations:
+      required: true
+  - type: textarea
+    id: what-should
+    attributes:
+      label: What should have happened?
+      description: Tell us what you think the normal behavior should be
+    validations:
+      required: true
+  - type: input
+    id: commit
+    attributes:
+      label: Commit where the problem happens
+      description: Which commit are you running? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+    validations:
+      required: true
+  - type: dropdown
+    id: platforms
+    attributes:
+      label: What platforms do you use to access the UI?
+      multiple: true
+      options:
+        - Windows
+        - Linux
+        - MacOS
+        - iOS
+        - Android
+        - Other/Cloud
+  - type: dropdown
+    id: browsers
+    attributes:
+      label: What browsers do you use to access the UI?
+      multiple: true
+      options:
+        - Mozilla Firefox
+        - Google Chrome
+        - Brave
+        - Apple Safari
+        - Microsoft Edge
+  - type: textarea
+    id: cmdargs
+    attributes:
+      label: Command Line Arguments
+      description: Are you using any launch parameters/command line arguments (modified webui-user.py)? If yes, please list them below.
+      render: Shell
+  - type: textarea
+    id: misc
+    attributes:
+      label: Additional information, context and logs
+      description: Please provide us with any relevant additional info, context or log output.
.github/ISSUE_TEMPLATE/config.yml (vendored, new file, 5 lines)

@@ -0,0 +1,5 @@
+blank_issues_enabled: false
+contact_links:
+  - name: WebUI Community Support
+    url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
+    about: Please ask and answer questions here.
.github/ISSUE_TEMPLATE/feature_request.md (vendored, 20 lines removed)

@@ -1,20 +0,0 @@
----
-name: Feature request
-about: Suggest an idea for this project
-title: ''
-labels: 'suggestion'
-assignees: ''
-
----
-
-**Is your feature request related to a problem? Please describe.**
-A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
-
-**Describe the solution you'd like**
-A clear and concise description of what you want to happen.
-
-**Describe alternatives you've considered**
-A clear and concise description of any alternative solutions or features you've considered.
-
-**Additional context**
-Add any other context or screenshots about the feature request here.
.github/ISSUE_TEMPLATE/feature_request.yml (vendored, new file, 40 lines)

@@ -0,0 +1,40 @@
+name: Feature request
+description: Suggest an idea for this project
+title: "[Feature Request]: "
+labels: ["suggestion"]
+
+body:
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
+      options:
+        - label: I have searched the existing issues and checked the recent builds/commits
+          required: true
+  - type: markdown
+    attributes:
+      value: |
+        *Please fill this form with as much information as possible; provide screenshots and/or illustrations of the feature if possible*
+  - type: textarea
+    id: feature
+    attributes:
+      label: What would your feature do?
+      description: Tell us about your feature in a very clear and simple way, and what problem it would solve
+    validations:
+      required: true
+  - type: textarea
+    id: workflow
+    attributes:
+      label: Proposed workflow
+      description: Please provide us with step by step information on how you'd like the feature to be accessed and used
+      value: |
+        1. Go to ....
+        2. Press ....
+        3. ...
+    validations:
+      required: true
+  - type: textarea
+    id: misc
+    attributes:
+      label: Additional information
+      description: Add any other context or screenshots about the feature request here.
.gitignore (vendored, 1 line added)

@@ -27,3 +27,4 @@ __pycache__
 notification.mp3
 /SwinIR
 /textual_inversion
+.vscode
README.md

@@ -82,8 +82,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - Use VAEs
 - Estimated completion time in progress bar
 - API
-- Support for dedicated inpainting model by RunwayML.
+- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
+- Aesthetic Gradients, a way to generate images with a specific aesthetic by using CLIP image embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
 
 ## Installation and Running
artists.csv

@@ -523,7 +523,6 @@ Affandi,0.7170285,nudity
 Diane Arbus,0.655138,digipa-high-impact
 Joseph Ducreux,0.65247905,digipa-high-impact
 Berthe Morisot,0.7165984,fineart
-Hilma AF Klint,0.71643853,scribbles
 Hilma af Klint,0.71643853,scribbles
 Filippino Lippi,0.7163017,fineart
 Leonid Afremov,0.7163005,fineart

@@ -738,14 +737,12 @@ Abraham Mignon,0.60605425,fineart
 Albert Bloch,0.69573116,nudity
 Charles Dana Gibson,0.67155975,fineart
 Alexandre-Évariste Fragonard,0.6507174,fineart
-Alexandre-Évariste Fragonard,0.6507174,fineart
 Ernst Fuchs,0.6953538,nudity
 Alfredo Jaar,0.6952965,digipa-high-impact
 Judy Chicago,0.6952246,weird
 Frans van Mieris the Younger,0.6951849,fineart
 Aertgen van Leyden,0.6951305,fineart
 Emily Carr,0.69512105,fineart
-Frances Macdonald,0.6950408,scribbles
 Frances MacDonald,0.6950408,scribbles
 Hannah Höch,0.69495845,scribbles
 Gillis Rombouts,0.58770025,fineart

@@ -895,7 +892,6 @@ Richard McGuire,0.6820089,scribbles
 Anni Albers,0.65708244,digipa-high-impact
 Aleksey Savrasov,0.65207493,fineart
 Wayne Barlowe,0.6537874,fineart
-Giorgio De Chirico,0.6815907,fineart
 Giorgio de Chirico,0.6815907,fineart
 Ernest Procter,0.6815795,fineart
 Adriaen Brouwer,0.6815058,fineart

@@ -1241,7 +1237,6 @@ Betty Churcher,0.65387225,fineart
 Claes Corneliszoon Moeyaert,0.65386075,fineart
 David Bomberg,0.6537477,fineart
 Abraham Bosschaert,0.6535562,fineart
-Giuseppe De Nittis,0.65354455,fineart
 Giuseppe de Nittis,0.65354455,fineart
 John La Farge,0.65342575,fineart
 Frits Thaulow,0.65341854,fineart

@@ -1522,7 +1517,6 @@ Gertrude Harvey,0.5903887,fineart
 Grant Wood,0.6266253,fineart
 Fyodor Vasilyev,0.5234919,digipa-med-impact
 Cagnaccio di San Pietro,0.6261671,fineart
-Cagnaccio Di San Pietro,0.6261671,fineart
 Doris Boulton-Maude,0.62593174,fineart
 Adolf Hirémy-Hirschl,0.5946784,fineart
 Harold von Schmidt,0.6256755,fineart

@@ -2411,7 +2405,6 @@ Hermann Feierabend,0.5346168,digipa-high-impact
 Antonio Donghi,0.4610982,digipa-low-impact
 Adonna Khare,0.4858036,digipa-med-impact
 James Stokoe,0.5015107,digipa-med-impact
-Art & Language,0.5341332,digipa-high-impact
 Agustín Fernández,0.53403986,fineart
 Germán Londoño,0.5338712,fineart
 Emmanuelle Moureaux,0.5335641,digipa-high-impact
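Every artists.csv hunk above removes a row whose name duplicates an existing entry under a different capitalization. Such rows can be located mechanically; a small sketch (the helper name and CSV handling are illustrative, not part of the commit):

```python
import csv
import io

def find_case_duplicates(csv_text):
    # Report names that appear more than once when compared
    # case-insensitively -- the kind of rows these hunks delete.
    seen = {}
    dupes = []
    for row in csv.reader(io.StringIO(csv_text)):
        if not row:
            continue
        key = row[0].casefold()
        if key in seen and row[0] != seen[key]:
            dupes.append(row[0])
        seen.setdefault(key, row[0])
    return dupes
```

Running it over the pre-commit CSV would flag exactly the seven case-variant pairs removed above.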
javascript/edit-attention.js

@@ -9,9 +9,38 @@ addEventListener('keydown', (event) => {
     let minus = "ArrowDown"
     if (event.key != plus && event.key != minus) return;
 
-    selectionStart = target.selectionStart;
-    selectionEnd = target.selectionEnd;
-    if(selectionStart == selectionEnd) return;
+    let selectionStart = target.selectionStart;
+    let selectionEnd = target.selectionEnd;
+    // If the user hasn't selected anything, let's select their current parenthesis block
+    if (selectionStart === selectionEnd) {
+        // Find opening parenthesis around current cursor
+        const before = target.value.substring(0, selectionStart);
+        let beforeParen = before.lastIndexOf("(");
+        if (beforeParen == -1) return;
+        let beforeParenClose = before.lastIndexOf(")");
+        while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
+            beforeParen = before.lastIndexOf("(", beforeParen - 1);
+            beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
+        }
+
+        // Find closing parenthesis around current cursor
+        const after = target.value.substring(selectionStart);
+        let afterParen = after.indexOf(")");
+        if (afterParen == -1) return;
+        let afterParenOpen = after.indexOf("(");
+        while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
+            afterParen = after.indexOf(")", afterParen + 1);
+            afterParenOpen = after.indexOf("(", afterParenOpen + 1);
+        }
+        if (beforeParen === -1 || afterParen === -1) return;
+
+        // Set the selection to the text between the parenthesis
+        const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
+        const lastColon = parenContent.lastIndexOf(":");
+        selectionStart = beforeParen + 1;
+        selectionEnd = selectionStart + lastColon;
+        target.setSelectionRange(selectionStart, selectionEnd);
+    }
 
     event.preventDefault();
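The added JS walks outward from the cursor, skipping balanced pairs on each side, to select the innermost enclosing `( ... )` block. The same scan expressed in Python, for reference (a sketch mirroring the logic above, not shipped code):

```python
def enclosing_parens(text, cursor):
    """Innermost '(' before cursor and ')' at or after it, skipping
    balanced pairs on each side, as in the JS above. None if the
    cursor is not inside a parenthesis block."""
    before, after = text[:cursor], text[cursor:]
    open_i, close_i = before.rfind("("), before.rfind(")")
    # Skip "(...)" pairs that are fully closed before the cursor
    while close_i != -1 and close_i > open_i:
        open_i = before.rfind("(", 0, max(open_i, 0))
        close_i = before.rfind(")", 0, max(close_i, 0))
    if open_i == -1:
        return None
    close_j, open_j = after.find(")"), after.find("(")
    # Skip "(...)" pairs that open again after the cursor
    while open_j != -1 and close_j > open_j:
        close_j = after.find(")", close_j + 1)
        open_j = after.find("(", open_j + 1)
    if close_j == -1:
        return None
    return open_i, cursor + close_j
```

For example, `enclosing_parens("foo (bar:1.1) baz", 6)` finds the pair at indices 4 and 12, while a cursor sitting between two closed groups matches nothing.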
javascript/hints.js

@@ -91,6 +91,8 @@ titles = {
 
     "Weighted sum": "Result = A * (1 - M) + B * M",
     "Add difference": "Result = A + (B - C) * M",
+
+    "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n  rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
 }
javascript/ui.js

@@ -1,5 +1,12 @@
 // various functions for interaction with ui.py not large enough to warrant putting them in separate files
 
+function set_theme(theme){
+    gradioURL = window.location.href
+    if (!gradioURL.includes('?__theme=')) {
+        window.location.replace(gradioURL + '?__theme=' + theme);
+    }
+}
+
 function selected_gallery_index(){
     var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
     var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')
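The new set_theme redirects only when the URL does not already carry Gradio's `__theme` query parameter; the same guard in Python terms (themed_url is illustrative and not a function in the repo):

```python
def themed_url(url, theme):
    # Mirror of the JS set_theme guard above: append Gradio's __theme
    # query parameter only when one is not already present, so the
    # page does not redirect in a loop.
    return url if '?__theme=' in url else f"{url}?__theme={theme}"
```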
launch.py (49 lines changed)

@@ -86,7 +86,24 @@ def git_clone(url, dir, name, commithash=None):
     if commithash is not None:
         run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
 
 
+def version_check(commit):
+    try:
+        import requests
+        commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
+        if commit != "<none>" and commits['commit']['sha'] != commit:
+            print("--------------------------------------------------------")
+            print("| You are not up to date with the most recent release. |")
+            print("| Consider running `git pull` to update.               |")
+            print("--------------------------------------------------------")
+        elif commits['commit']['sha'] == commit:
+            print("You are up to date with the most recent release.")
+        else:
+            print("Not a git clone, can't perform version check.")
+    except Exception as e:
+        print("version check failed", e)
+
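version_check above compares the local HEAD hash against the master branch SHA fetched from the GitHub API. Its decision logic, factored into a pure function for clarity (update_status is illustrative, not in launch.py):

```python
def update_status(local_commit, remote_sha):
    # Same branching as version_check above: "<none>" means `git
    # rev-parse HEAD` failed, so this is not a git clone; otherwise
    # compare the local hash with the remote branch head.
    if local_commit == "<none>":
        return "Not a git clone, can't perform version check."
    if remote_sha != local_commit:
        return "You are not up to date with the most recent release."
    return "You are up to date with the most recent release."
```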
 
 def prepare_enviroment():
     torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")

@@ -110,13 +127,14 @@ def prepare_enviroment():
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
 
-    args = shlex.split(commandline_args)
+    sys.argv += shlex.split(commandline_args)
 
-    args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
-    args, reinstall_xformers = extract_arg(args, '--reinstall-xformers')
-    xformers = '--xformers' in args
-    deepdanbooru = '--deepdanbooru' in args
-    ngrok = '--ngrok' in args
+    sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
+    sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
+    sys.argv, update_check = extract_arg(sys.argv, '--update-check')
+    xformers = '--xformers' in sys.argv
+    deepdanbooru = '--deepdanbooru' in sys.argv
+    ngrok = '--ngrok' in sys.argv
 
     try:
         commit = run(f"{git} rev-parse HEAD").strip()

@@ -125,7 +143,7 @@ def prepare_enviroment():
 
     print(f"Python {sys.version}")
     print(f"Commit hash: {commit}")
 
     if not is_installed("torch") or not is_installed("torchvision"):
         run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")

@@ -138,9 +156,15 @@ def prepare_enviroment():
     if not is_installed("clip"):
         run_pip(f"install {clip_package}", "clip")
 
-    if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"):
+    if (not is_installed("xformers") or reinstall_xformers) and xformers:
         if platform.system() == "Windows":
-            run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
+            if platform.python_version().startswith("3.10"):
+                run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
+            else:
+                print("Installation of xformers is not supported in this version of Python.")
+                print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
+                if not is_installed("xformers"):
+                    exit(0)
         elif platform.system() == "Linux":
             run_pip("install xformers", "xformers")

@@ -163,9 +187,10 @@ def prepare_enviroment():
 
     run_pip(f"install -r {requirements_file}", "requirements for Web UI")
 
-    sys.argv += args
+    if update_check:
+        version_check(commit)
 
-    if "--exit" in args:
+    if "--exit" in sys.argv:
         print("Exiting because of --exit argument")
         exit(0)
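extract_arg, called throughout the hunk above, is not shown in this diff. A minimal sketch consistent with its call sites (remove a flag from the argument list and report whether it was present) would be:

```python
def extract_arg(args, name):
    # Plausible shape for launch.py's extract_arg (not shown in the
    # hunk above): strip every occurrence of `name` from the argument
    # list and report whether it was present.
    return [x for x in args if x != name], name in args
```

This is why the changed lines can write `sys.argv, update_check = extract_arg(sys.argv, '--update-check')`: the flag is consumed before argparse ever sees it.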
241
modules/aesthetic_clip.py
Normal file
241
modules/aesthetic_clip.py
Normal file
|
@ -0,0 +1,241 @@
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import html
|
||||||
|
import gc
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch import optim
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
|
||||||
|
from tqdm.auto import tqdm, trange
|
||||||
|
from modules.shared import opts, device
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_images_in_folder(folder):
|
||||||
|
return [os.path.join(folder, f) for f in os.listdir(folder) if
|
||||||
|
os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
|
||||||
|
|
||||||
|
|
||||||
|
def check_is_valid_image_file(filename):
|
||||||
|
return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp"))
|
||||||
|
|
||||||
|
|
||||||
|
def batched(dataset, total, n=1):
|
||||||
|
for ndx in range(0, total, n):
|
||||||
|
yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
|
||||||
|
|
||||||
|
|
||||||
|
def iter_to_batched(iterable, n=1):
|
||||||
|
it = iter(iterable)
|
||||||
|
while True:
|
||||||
|
chunk = tuple(itertools.islice(it, n))
|
||||||
|
if not chunk:
|
||||||
|
return
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui():
|
||||||
|
import modules.ui
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
with gr.Accordion("Open for Clip Aesthetic!", open=False):
|
||||||
|
with gr.Row():
|
||||||
|
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
|
||||||
|
value=0.9)
|
||||||
|
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
|
||||||
|
placeholder="Aesthetic learning rate", value="0.0001")
|
||||||
|
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
|
||||||
|
aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
|
||||||
|
label="Aesthetic imgs embedding",
|
||||||
|
value="None")
|
||||||
|
|
||||||
|
modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
|
||||||
|
placeholder="This text is used to rotate the feature space of the imgs embs",
|
||||||
|
value="")
|
||||||
|
aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
|
||||||
|
value=0.1)
|
||||||
|
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
|
||||||
|
|
||||||
|
return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
|
||||||
|
|
||||||
|
|
||||||
|
aesthetic_clip_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def aesthetic_clip():
|
||||||
|
global aesthetic_clip_model
|
||||||
|
|
||||||
|
if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
|
||||||
|
aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
|
||||||
|
aesthetic_clip_model.cpu()
|
||||||
|
|
||||||
|
return aesthetic_clip_model
|
||||||
|
|
||||||
|
|
||||||
|
def generate_imgs_embd(name, folder, batch_size):
|
||||||
|
model = aesthetic_clip().to(device)
|
||||||
|
processor = CLIPProcessor.from_pretrained(model.name_or_path)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
embs = []
|
||||||
|
for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
|
||||||
|
desc=f"Generating embeddings for {name}"):
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
|
||||||
|
outputs = model.get_image_features(**inputs).cpu()
|
||||||
|
embs.append(torch.clone(outputs))
|
||||||
|
inputs.to("cpu")
|
||||||
|
del inputs, outputs
|
||||||
|
|
||||||
|
embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
|
||||||
|
|
||||||
|
# The generated embedding will be located here
|
||||||
|
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
|
||||||
|
torch.save(embs, path)
|
||||||
|
|
||||||
|
model.cpu()
|
||||||
|
del processor
|
||||||
|
del embs
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
res = f"""
|
||||||
|
Done generating embedding for {name}!
|
||||||
|
Aesthetic embedding saved to {html.escape(path)}
|
||||||
|
"""
|
||||||
|
shared.update_aesthetic_embeddings()
|
||||||
|
return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
|
||||||
|
value="None"), \
|
||||||
|
gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
|
||||||
|
label="Imgs embedding",
|
||||||
|
value="None"), res, ""
|
||||||
|
|
||||||
|
|
||||||
|
def slerp(low, high, val):
|
||||||
|
low_norm = low / torch.norm(low, dim=1, keepdim=True)
|
||||||
|
high_norm = high / torch.norm(high, dim=1, keepdim=True)
|
||||||
|
omega = torch.acos((low_norm * high_norm).sum(1))
|
||||||
|
so = torch.sin(omega)
|
||||||
|
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class AestheticCLIP:
|
||||||
|
def __init__(self):
|
||||||
|
self.skip = False
|
||||||
|
self.aesthetic_steps = 0
|
||||||
|
self.aesthetic_weight = 0
|
||||||
|
self.aesthetic_lr = 0
|
||||||
|
self.slerp = False
|
||||||
|
self.aesthetic_text_negative = ""
|
||||||
|
self.aesthetic_slerp_angle = 0
|
||||||
|
self.aesthetic_imgs_text = ""
|
||||||
|
|
||||||
|
self.image_embs_name = None
|
||||||
|
self.image_embs = None
|
||||||
|
self.load_image_embs(None)
|
||||||
|
|
||||||
|
def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
|
||||||
|
aesthetic_slerp=True, aesthetic_imgs_text="",
|
||||||
|
aesthetic_slerp_angle=0.15,
|
||||||
|
aesthetic_text_negative=False):
|
||||||
|
self.aesthetic_imgs_text = aesthetic_imgs_text
|
||||||
|
self.aesthetic_slerp_angle = aesthetic_slerp_angle
|
||||||
|
self.aesthetic_text_negative = aesthetic_text_negative
|
||||||
|
self.slerp = aesthetic_slerp
|
||||||
|
self.aesthetic_lr = aesthetic_lr
|
||||||
|
self.aesthetic_weight = aesthetic_weight
|
||||||
|
self.aesthetic_steps = aesthetic_steps
|
||||||
|
self.load_image_embs(image_embs_name)
|
||||||
|
|
||||||
|
if self.image_embs_name is not None:
|
||||||
|
p.extra_generation_params.update({
|
||||||
|
"Aesthetic LR": aesthetic_lr,
|
||||||
|
"Aesthetic weight": aesthetic_weight,
|
||||||
|
"Aesthetic steps": aesthetic_steps,
|
||||||
|
"Aesthetic embedding": self.image_embs_name,
|
||||||
|
"Aesthetic slerp": aesthetic_slerp,
|
||||||
|
"Aesthetic text": aesthetic_imgs_text,
|
||||||
|
"Aesthetic text negative": aesthetic_text_negative,
|
||||||
|
"Aesthetic slerp angle": aesthetic_slerp_angle,
|
||||||
|
})
|
||||||
|
|
||||||
|
def set_skip(self, skip):
|
||||||
|
self.skip = skip
|
||||||
|
|
||||||
|
def load_image_embs(self, image_embs_name):
|
||||||
|
if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
|
||||||
|
image_embs_name = None
|
||||||
|
self.image_embs_name = None
|
||||||
|
if image_embs_name is not None and self.image_embs_name != image_embs_name:
|
||||||
|
self.image_embs_name = image_embs_name
|
||||||
|
self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
|
||||||
|
self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
|
||||||
|
self.image_embs.requires_grad_(False)
|
||||||
|
|
||||||
|
def __call__(self, z, remade_batch_tokens):
|
||||||
|
if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
|
||||||
|
tokenizer = shared.sd_model.cond_stage_model.tokenizer
|
||||||
|
if not opts.use_old_emphasis_implementation:
|
||||||
|
remade_batch_tokens = [
|
||||||
|
[tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
|
||||||
|
remade_batch_tokens]
|
||||||
|
|
||||||
|
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||||
|
|
||||||
|
model = copy.deepcopy(aesthetic_clip()).to(device)
|
||||||
|
model.requires_grad_(True)
|
||||||
|
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
|
||||||
|
text_embs_2 = model.get_text_features(
|
||||||
|
**tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
|
||||||
|
if self.aesthetic_text_negative:
|
||||||
|
text_embs_2 = self.image_embs - text_embs_2
|
||||||
|
text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
|
||||||
|
img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
|
||||||
|
else:
|
||||||
|
img_embs = self.image_embs
|
||||||
|
|
||||||
|
with torch.enable_grad():
|
||||||
|
|
||||||
|
# We optimize the model to maximize the similarity
|
||||||
|
optimizer = optim.Adam(
|
||||||
|
model.text_model.parameters(), lr=self.aesthetic_lr
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
|
||||||
|
text_embs = model.get_text_features(input_ids=tokens)
|
||||||
|
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
|
||||||
|
sim = text_embs @ img_embs.T
|
||||||
|
loss = -sim
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.mean().backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||||
|
if opts.CLIP_stop_at_last_layers > 1:
|
||||||
|
zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||||
|
zn = model.text_model.final_layer_norm(zn)
|
||||||
|
else:
|
||||||
|
zn = zn.last_hidden_state
|
||||||
|
model.cpu()
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
|
||||||
|
if self.slerp:
|
||||||
|
z = slerp(z, zn, self.aesthetic_weight)
|
||||||
|
else:
|
||||||
|
z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
|
||||||
|
|
||||||
|
return z
|
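The final blend above mixes the original conditioning `z` with the optimized `zn` either linearly or spherically. A minimal pure-Python sketch of spherical interpolation (a hypothetical stand-in for the repo's `slerp` helper; it assumes unit-length input vectors):

```python
import math

def slerp(v0, v1, t):
    # Spherical linear interpolation between two unit vectors: walk along the
    # great circle between v0 and v1 by fraction t of the angle between them.
    dot = max(-1.0, min(1.0, sum(a * b for a, b in zip(v0, v1))))
    theta = math.acos(dot)
    if theta < 1e-6:  # nearly parallel: fall back to plain lerp
        return [(1 - t) * a + t * b for a, b in zip(v0, v1)]
    s0 = math.sin((1 - t) * theta) / math.sin(theta)
    s1 = math.sin(t * theta) / math.sin(theta)
    return [s0 * a + s1 * b for a, b in zip(v0, v1)]
```

Unlike the linear branch, slerp keeps intermediate results on the unit sphere, which matters because the embeddings above are normalized before blending.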
modules/api/api.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
import json
import io
import base64

sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)


class TextToImageResponse(BaseModel):
    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    parameters: Json
    info: Json


class Api:
    def __init__(self, app, queue_lock):
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])

    def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)

        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True
            }
        )
        p = StableDiffusionProcessingTxt2Img(**vars(populate))
        # Override object param
        with self.queue_lock:
            processed = process_images(p)

        b64images = []
        for i in processed.images:
            buffer = io.BytesIO()
            i.save(buffer, format="png")
            b64images.append(base64.b64encode(buffer.getvalue()))

        return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))

    def img2imgapi(self):
        raise NotImplementedError

    def extrasapi(self):
        raise NotImplementedError

    def pnginfoapi(self):
        raise NotImplementedError

    def launch(self, server_name, port):
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)
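`text2imgapi` returns each image base64-encoded from a `BytesIO` PNG buffer. A client simply reverses that encoding; a sketch using stdlib only, where the payload bytes stand in for an actual PNG body:

```python
import base64
import io

def decode_image(b64_str: str) -> bytes:
    # Reverse of the server-side base64.b64encode(buffer.getvalue()) step.
    return base64.b64decode(b64_str)

# Stand-in payload (a real response carries a full PNG after the signature).
png_bytes = b"\x89PNG\r\n\x1a\n...fake..."
encoded = base64.b64encode(io.BytesIO(png_bytes).getvalue()).decode()
```

`decode_image(encoded)` recovers `png_bytes` exactly, which a client would then write to disk or hand to an image library.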
modules/api/processing.py (new file, 99 lines)
@@ -0,0 +1,99 @@
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
from modules.processing import StableDiffusionProcessingTxt2Img
import inspect


API_NOT_ALLOWED = [
    "self",
    "kwargs",
    "sd_model",
    "outpath_samples",
    "outpath_grids",
    "sampler_index",
    "do_not_save_samples",
    "do_not_save_grid",
    "extra_generation_params",
    "overlay_images",
    "do_not_reload_embeddings",
    "seed_enable_extras",
    "prompt_for_display",
    "sampler_noise_scheduler_override",
    "ddim_discretize"
]

class ModelDef(BaseModel):
    """Assistance Class for Pydantic Dynamic Model Generation"""

    field: str
    field_alias: str
    field_type: Any
    field_value: Any


class PydanticModelGenerator:
    """
    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
    source_data is a snapshot of the default values produced by the class
    params are the names of the actual keys required by __init__
    """

    def __init__(
        self,
        model_name: str = None,
        class_instance = None,
        additional_fields = None,
    ):
        def field_type_generator(k, v):
            # field_type = str if not overrides.get(k) else overrides[k]["type"]
            # print(k, v.annotation, v.default)
            field_type = v.annotation

            return Optional[field_type]

        def merge_class_params(class_):
            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
            parameters = {}
            for classes in all_classes:
                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
            return parameters


        self._model_name = model_name
        self._class_data = merge_class_params(class_instance)
        self._model_def = [
            ModelDef(
                field=underscore(k),
                field_alias=k,
                field_type=field_type_generator(k, v),
                field_value=v.default
            )
            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
        ]

        for fields in additional_fields:
            self._model_def.append(ModelDef(
                field=underscore(fields["key"]),
                field_alias=fields["key"],
                field_type=fields["type"],
                field_value=fields["default"]))

    def generate_model(self):
        """
        Creates a pydantic BaseModel
        from the json and overrides provided at initialization
        """
        fields = {
            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
        }
        DynamicModel = create_model(self._model_name, **fields)
        DynamicModel.__config__.allow_population_by_field_name = True
        DynamicModel.__config__.allow_mutation = True
        return DynamicModel

StableDiffusionProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingTxt2Img",
    StableDiffusionProcessingTxt2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()
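The core trick in `merge_class_params` above is walking the MRO and merging `__init__` signatures. A self-contained sketch with hypothetical `Base`/`Derived` classes; note that because each later class in the MRO is merged on top, a base class's default wins when the same parameter appears twice:

```python
import inspect

class Base:
    def __init__(self, width=512, height=512):
        pass

class Derived(Base):
    def __init__(self, width=768, steps=20):
        pass

def merge_class_params(class_):
    # Same shape as the generator above: iterate the MRO (Derived, then Base)
    # and merge each __init__ signature's parameters into one dict.
    params = {}
    for cls in filter(lambda x: x is not object, inspect.getmro(class_)):
        params = {**params, **inspect.signature(cls.__init__).parameters}
    return params

merged = merge_class_params(Derived)
# 'steps' comes from Derived, 'height' from Base; for 'width', Base's
# default (512) overwrites Derived's because Base is merged last.
```

This is why the generated API model sees the union of parameters across the whole processing class hierarchy.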
@@ -157,8 +157,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o
     # sort by reverse by likelihood and normal for alpha, and format tag text as requested
     unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
     for weight, tag in unsorted_tags_in_theshold:
-        # note: tag_outformat will still have a colon if include_ranks is True
-        tag_outformat = tag
+        tag_outformat = tag.replace(':', ' ')
         if use_spaces:
             tag_outformat = tag_outformat.replace('_', ' ')
         if use_escape:
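The sort in the hunk above uses one comparator for both orderings: `sort_ndx` picks the tuple element (likelihood or tag name) and `reverse` flips direction only for the likelihood case. A sketch with sample data:

```python
# (likelihood, tag) pairs, as in the deepbooru hunk above.
tags = [(0.9, "sky"), (0.5, "cat"), (0.7, "tree")]

alpha_sort = False
sort_ndx = 1 if alpha_sort else 0  # 0 = sort by likelihood, 1 = by name
tags.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
# Likelihood mode: descending, most confident tag first.
```

With `alpha_sort = True` the same line would instead sort ascending by tag name.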
@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         if input_dir == '':
             return outputs, "Please select an input directory.", ''
-        image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+        image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
         for img in image_list:
-            image = Image.open(img)
+            try:
+                image = Image.open(img)
+            except Exception:
+                continue
             imageArr.append(image)
             imageNameArr.append(img)
     else:
@@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
 
         while len(cached_images) > 2:
             del cached_images[next(iter(cached_images.keys()))]
 
-        images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
-                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
-                          forced_filename=image_name if opts.use_original_name_batch else None)
+        if opts.use_original_name_batch and image_name != None:
+            basename = os.path.splitext(os.path.basename(image_name))[0]
+        else:
+            basename = ''
+
+        images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
 
         if opts.enable_pnginfo:
             image.info = existing_pnginfo
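The basename derivation added above strips both the directory and the extension from the original image path, so batch outputs reuse the source file's name:

```python
import os

def derive_basename(image_name):
    # Same expression as in the hunk above: basename without extension.
    return os.path.splitext(os.path.basename(image_name))[0]
```

For example, `/tmp/pics/cat.png` yields `cat`, and a bare `dog.jpeg` yields `dog`.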
@@ -216,8 +223,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
     if theta_func1:
         for key in tqdm.tqdm(theta_1.keys()):
             if 'model' in key:
-                t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
-                theta_1[key] = theta_func1(theta_1[key], t2)
+                if key in theta_2:
+                    t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
+                    theta_1[key] = theta_func1(theta_1[key], t2)
+                else:
+                    theta_1[key] = torch.zeros_like(theta_1[key])
         del theta_2, teritary_model
 
     for key in tqdm.tqdm(theta_0.keys()):
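The guard added in that hunk zeroes out model keys missing from the second checkpoint instead of merging against a fabricated tensor. A sketch of the control flow with plain floats standing in for tensors:

```python
def merge_add_diff(theta_1, theta_2, theta_func1):
    # Mirror of the hunk above: only 'model' keys are touched; keys absent
    # from theta_2 are zeroed rather than passed through theta_func1.
    for key in list(theta_1):
        if 'model' in key:
            if key in theta_2:
                theta_1[key] = theta_func1(theta_1[key], theta_2[key])
            else:
                theta_1[key] = 0.0  # stands in for torch.zeros_like
    return theta_1
```

With `theta_func1 = lambda a, b: a - b`, a key present in both dicts gets the difference, a `model` key missing from `theta_2` becomes zero, and non-model keys are left untouched.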
@@ -4,13 +4,22 @@ import gradio as gr
 from modules.shared import script_path
 from modules import shared
 
-re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
 re_param = re.compile(re_param_code)
 re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
 re_imagesize = re.compile(r"^(\d+)x(\d+)$")
 type_of_gr_update = type(gr.update())
 
+
+def quote(text):
+    if ',' not in str(text):
+        return text
+
+    text = str(text)
+    text = text.replace('\\', '\\\\')
+    text = text.replace('"', '\\"')
+    return f'"{text}"'
+
+
 def parse_generation_parameters(x: str):
     """parses generation parameters string, the one you see in text field under the picture in UI:
     ```
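The new `re_param_code` and `quote` form a round trip: values containing commas are written as escaped quoted strings so the parameter regex can still split the line on commas. A runnable sketch applying both to a sample infotext last line (`quote` reproduced verbatim from the hunk):

```python
import re

re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)

def quote(text):
    if ',' not in str(text):
        return text

    text = str(text)
    text = text.replace('\\', '\\\\')
    text = text.replace('"', '\\"')
    return f'"{text}"'

line = "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086"
params = {k: v for k, v in re_param.findall(line)}
# Keys with spaces ("CFG scale") and values with spaces ("Euler a") both parse.
```

A comma-free value passes through `quote` unchanged; a value like `a,b` comes back as `"a,b"`, which the quoted-string alternative in the regex then matches as a single token.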
@@ -45,11 +54,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
         else:
             prompt += ("" if prompt == "" else "\n") + line
 
-    if len(prompt) > 0:
-        res["Prompt"] = prompt
-
-    if len(negative_prompt) > 0:
-        res["Negative prompt"] = negative_prompt
+    res["Prompt"] = prompt
+    res["Negative prompt"] = negative_prompt
 
     for k, v in re_param.findall(lastline):
         m = re_imagesize.match(v)
@@ -86,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
         else:
             try:
                 valtype = type(output.value)
-                val = valtype(v)
+
+                if valtype == bool and v == "False":
+                    val = False
+                else:
+                    val = valtype(v)
+
                 res.append(gr.update(value=val))
             except Exception:
                 res.append(gr.update())
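The special case added above exists because casting any non-empty string with `bool()` yields `True`, so a pasted `"False"` would always round-trip as `True`. A minimal sketch of the corrected coercion:

```python
def coerce(valtype, v):
    # bool("False") is True, so the string must be checked explicitly
    # before falling back to the generic valtype(v) cast.
    if valtype == bool and v == "False":
        return False
    return valtype(v)
```

Other types are unaffected: `coerce(int, "7")` still behaves exactly like `int("7")`.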
@@ -22,25 +22,67 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
 
-    def __init__(self, dim, state_dict=None):
+    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
         super().__init__()
 
-        self.linear1 = torch.nn.Linear(dim, dim * 2)
-        self.linear2 = torch.nn.Linear(dim * 2, dim)
+        assert layer_structure is not None, "layer_structure must not be None"
+        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
+
+        linears = []
+        for i in range(len(layer_structure) - 1):
+            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
+            if activation_func == "relu":
+                linears.append(torch.nn.ReLU())
+            elif activation_func == "leakyrelu":
+                linears.append(torch.nn.LeakyReLU())
+            elif activation_func == 'linear' or activation_func is None:
+                pass
+            else:
+                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
+            if add_layer_norm:
+                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+
+        self.linear = torch.nn.Sequential(*linears)
 
         if state_dict is not None:
-            self.load_state_dict(state_dict, strict=True)
+            self.fix_old_state_dict(state_dict)
+            self.load_state_dict(state_dict)
         else:
-            self.linear1.weight.data.normal_(mean=0.0, std=0.01)
-            self.linear1.bias.data.zero_()
-            self.linear2.weight.data.normal_(mean=0.0, std=0.01)
-            self.linear2.bias.data.zero_()
+            for layer in self.linear:
+                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+                    layer.weight.data.normal_(mean=0.0, std=0.01)
+                    layer.bias.data.zero_()
 
         self.to(devices.device)
 
+    def fix_old_state_dict(self, state_dict):
+        changes = {
+            'linear1.bias': 'linear.0.bias',
+            'linear1.weight': 'linear.0.weight',
+            'linear2.bias': 'linear.1.bias',
+            'linear2.weight': 'linear.1.weight',
+        }
+
+        for fr, to in changes.items():
+            x = state_dict.get(fr, None)
+            if x is None:
+                continue
+
+            del state_dict[fr]
+            state_dict[to] = x
+
     def forward(self, x):
-        return x + (self.linear2(self.linear1(x))) * self.multiplier
+        return x + self.linear(x) * self.multiplier
+
+    def trainables(self):
+        layer_structure = []
+        for layer in self.linear:
+            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+                layer_structure += [layer.weight, layer.bias]
+        return layer_structure
 
 
 def apply_strength(value=None):
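The `layer_structure` list above is a sequence of multipliers of `dim`: each adjacent pair becomes one `torch.nn.Linear(in, out)`. A torch-free sketch of how the structure maps to layer shapes:

```python
def layer_shapes(dim, layer_structure):
    # Each consecutive multiplier pair in layer_structure defines one linear
    # layer, exactly as the Linear construction loop in the hunk above.
    return [
        (int(dim * layer_structure[i]), int(dim * layer_structure[i + 1]))
        for i in range(len(layer_structure) - 1)
    ]
```

The default `[1, 2, 1]` with `dim=768` reproduces the old fixed two-layer shape: 768→1536 followed by 1536→768, which is also why `fix_old_state_dict` can remap `linear1`/`linear2` onto `linear.0`/`linear.1`.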
@@ -51,16 +93,22 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
         self.filename = None
         self.name = name
         self.layers = {}
         self.step = 0
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None
+        self.layer_structure = layer_structure
+        self.add_layer_norm = add_layer_norm
+        self.activation_func = activation_func
 
         for size in enable_sizes or []:
-            self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
+            self.layers[size] = (
+                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
+                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
+            )
 
     def weights(self):
         res = []
@@ -68,7 +116,7 @@ class Hypernetwork:
         for k, layers in self.layers.items():
             for layer in layers:
                 layer.train()
-                res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
+                res += layer.trainables()
 
         return res
 
@@ -80,6 +128,9 @@ class Hypernetwork:
 
         state_dict['step'] = self.step
         state_dict['name'] = self.name
+        state_dict['layer_structure'] = self.layer_structure
+        state_dict['is_layer_norm'] = self.add_layer_norm
+        state_dict['activation_func'] = self.activation_func
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
 
@@ -92,9 +143,16 @@ class Hypernetwork:
 
         state_dict = torch.load(filename, map_location='cpu')
 
+        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+        self.add_layer_norm = state_dict.get('is_layer_norm', False)
+        self.activation_func = state_dict.get('activation_func', None)
+
         for size, sd in state_dict.items():
             if type(size) == int:
-                self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+                self.layers[size] = (
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
+                )
 
         self.name = state_dict.get('name', self.name)
         self.step = state_dict.get('step', 0)
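The `.get()` calls added to `Hypernetwork.load` above are what keep older checkpoints loadable: files written before this change simply lack the new metadata keys, so the defaults reproduce the legacy fixed shape. A sketch with a hypothetical old-format checkpoint dict:

```python
# An old-format checkpoint: only the keys that existed before this change.
old_checkpoint = {"step": 1000, "name": "demo"}

# Missing keys fall back to the legacy defaults; [1, 2, 1] matches the
# old hard-coded dim -> dim*2 -> dim module shape.
layer_structure = old_checkpoint.get("layer_structure", [1, 2, 1])
add_layer_norm = old_checkpoint.get("is_layer_norm", False)
activation_func = old_checkpoint.get("activation_func", None)
```

A checkpoint saved after this change carries all three keys explicitly, so the same code reads them through unchanged.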
@@ -196,7 +254,11 @@ def stack_conds(conds):
 
     return torch.stack(conds)
 
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
+    from modules import images
+
     assert hypernetwork_name, 'hypernetwork not selected'
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -225,8 +287,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
 
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
@@ -240,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
     last_saved_file = "<none>"
     last_saved_image = "<none>"
+    forced_filename = "<none>"
 
     ititial_step = hypernetwork.step or 0
     if ititial_step > steps:
@@ -261,7 +323,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
         with torch.autocast("cuda"):
             c = stack_conds([entry.cond for entry in entries]).to(devices.device)
             # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
             x = torch.stack([entry.latent for entry in entries]).to(devices.device)
             loss = shared.sd_model(x, c)[0]
             del x
@@ -287,7 +349,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         })
 
         if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
-            last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+            forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+            last_saved_image = os.path.join(images_dir, forced_filename)
 
             optimizer.zero_grad()
             shared.sd_model.cond_stage_model.to(devices.device)
@@ -323,7 +386,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
             if image is not None:
                 shared.state.current_image = image
-                image.save(last_saved_image)
+                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
                 last_saved_image += f", prompt: {preview_text}"
 
         shared.state.job_no = hypernetwork.step
@@ -333,7 +396,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 Loss: {mean_loss:.7f}<br/>
 Step: {hypernetwork.step}<br/>
 Last prompt: {html.escape(entries[0].cond_text)}<br/>
-Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
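The `forced_filename` introduced above ties every saved preview image to the hypernetwork's name and the training step it was rendered at, replacing the hand-built `.png` path:

```python
def preview_filename(hypernetwork_name, step):
    # Same f-string as the training loop above; images.save_image appends
    # the extension itself when given forced_filename.
    return f"{hypernetwork_name}-{step}"
```

So a hypernetwork named `mystyle` at step 1500 saves its preview under `mystyle-1500`.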
@@ -1,5 +1,6 @@
 import html
 import os
+import re
 
 import gradio as gr
 
@@ -9,11 +10,21 @@ from modules import sd_hijack, shared, devices
 from modules.hypernetworks import hypernetwork
 
 
-def create_hypernetwork(name, enable_sizes):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
-    assert not os.path.exists(fn), f"file {fn} already exists"
+    if not overwrite_old:
+        assert not os.path.exists(fn), f"file {fn} already exists"
 
-    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
+    if type(layer_structure) == str:
+        layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+
+    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+        name=name,
+        enable_sizes=[int(x) for x in enable_sizes],
+        layer_structure=layer_structure,
+        add_layer_norm=add_layer_norm,
+        activation_func=activation_func,
+    )
    hypernet.save(fn)
 
     shared.reload_hypernetworks()
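The string branch added to `create_hypernetwork` above is how the UI's free-text field reaches the model: a comma-separated entry like `"1, 2, 1"` becomes the list of float multipliers that `HypernetworkModule` expects. Extracted as a standalone sketch:

```python
def parse_layer_structure(value):
    # Same conversion as in create_hypernetwork above: a UI string becomes
    # float multipliers; an already-parsed list passes through unchanged.
    if isinstance(value, str):
        return [float(x.strip()) for x in value.split(",")]
    return value
```

Whitespace around the commas is tolerated, so `"1, 2, 1"` and `"1,2,1"` parse identically.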
@@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args):
             processed_image.save(os.path.join(output_dir, filename))
 
 
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
     is_inpaint = mode == 1
     is_batch = mode == 2
 
@@ -109,6 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         inpainting_mask_invert=inpainting_mask_invert,
     )
 
+    shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
+
     if shared.cmd_opts.enable_console_prompts:
         print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
 
@@ -28,9 +28,11 @@ class InterrogateModels:
     clip_preprocess = None
     categories = None
     dtype = None
+    running_on_cpu = None

     def __init__(self, content_dir):
         self.categories = []
+        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")

         if os.path.exists(content_dir):
             for filename in os.listdir(content_dir):
@@ -53,7 +55,11 @@ class InterrogateModels:
     def load_clip_model(self):
         import clip

-        model, preprocess = clip.load(clip_model_name)
+        if self.running_on_cpu:
+            model, preprocess = clip.load(clip_model_name, device="cpu")
+        else:
+            model, preprocess = clip.load(clip_model_name)

         model.eval()
         model = model.to(devices.device_interrogate)
@@ -62,14 +68,14 @@ class InterrogateModels:
     def load(self):
         if self.blip_model is None:
             self.blip_model = self.load_blip_model()
-            if not shared.cmd_opts.no_half:
+            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                 self.blip_model = self.blip_model.half()

         self.blip_model = self.blip_model.to(devices.device_interrogate)

         if self.clip_model is None:
             self.clip_model, self.clip_preprocess = self.load_clip_model()
-            if not shared.cmd_opts.no_half:
+            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                 self.clip_model = self.clip_model.half()

         self.clip_model = self.clip_model.to(devices.device_interrogate)
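The interrogate changes above only apply `.half()` when the model is not running on CPU, since many CPU ops lack float16 kernels. The gating condition can be sketched as a small predicate (the helper name is hypothetical, not part of the repo):

```python
def should_use_half(no_half: bool, running_on_cpu: bool) -> bool:
    # Half precision is only worthwhile on GPU; CPU inference
    # generally needs to stay in float32.
    return not no_half and not running_on_cpu

# Mirrors `if not shared.cmd_opts.no_half and not self.running_on_cpu:`
assert should_use_half(no_half=False, running_on_cpu=False) is True
assert should_use_half(no_half=False, running_on_cpu=True) is False
assert should_use_half(no_half=True, running_on_cpu=False) is False
```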
@@ -9,9 +9,10 @@ from PIL import Image, ImageFilter, ImageOps
 import random
 import cv2
 from skimage import exposure
+from typing import Any, Dict, List, Optional

 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
|
@@ -51,9 +52,15 @@ def get_correct_sampler(p):
         return sd_samplers.samplers
     elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
         return sd_samplers.samplers_for_img2img
+    elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
+        return sd_samplers.samplers

-class StableDiffusionProcessing:
-    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False):
+class StableDiffusionProcessing():
+    """
+    The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
+    """
+    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
         self.sd_model = sd_model
         self.outpath_samples: str = outpath_samples
         self.outpath_grids: str = outpath_grids
|
@@ -86,10 +93,10 @@ class StableDiffusionProcessing:
         self.denoising_strength: float = 0
         self.sampler_noise_scheduler_override = None
         self.ddim_discretize = opts.ddim_discretize
-        self.s_churn = opts.s_churn
-        self.s_tmin = opts.s_tmin
-        self.s_tmax = float('inf')  # not representable as a standard ui option
-        self.s_noise = opts.s_noise
+        self.s_churn = s_churn or opts.s_churn
+        self.s_tmin = s_tmin or opts.s_tmin
+        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
+        self.s_noise = s_noise or opts.s_noise

         if not seed_enable_extras:
             self.subseed = -1
@@ -97,6 +104,7 @@ class StableDiffusionProcessing:
             self.seed_resize_from_h = 0
             self.seed_resize_from_w = 0

+
     def init(self, all_prompts, all_seeds, all_subseeds):
         pass
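The `s_churn or opts.s_churn` pattern above falls back to the global option whenever the constructor argument is falsy. Note the trade-off of this idiom: an explicitly passed `0.0` also triggers the fallback, which happens to be harmless here because `0.0` matches the defaults. A minimal sketch:

```python
def resolve(value, default):
    # `value or default` picks `default` for ANY falsy value,
    # including an explicit 0.0 -- not just for None.
    return value or default

assert resolve(0.5, 0.1) == 0.5
assert resolve(None, 0.1) == 0.1
assert resolve(0.0, 0.1) == 0.1  # explicit zero still falls back
```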
@@ -296,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
-        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
+        "Hypernet": (None if shared.loaded_hypernetwork is None else os.path.splitext(os.path.basename(shared.loaded_hypernetwork.filename))[0]),
         "Batch size": (None if p.batch_size < 2 else p.batch_size),
         "Batch pos": (None if p.batch_size < 2 else position_in_batch),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -310,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration

     generation_params.update(p.extra_generation_params)

-    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

     negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
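The switch to `generation_parameters_copypaste.quote(v)` above presumably exists so that values containing commas no longer break the `", "`-joined infotext. A rough sketch of such quoting (assumed behavior for illustration, not the repo's exact implementation):

```python
def quote(value):
    # Wrap values containing a comma or quote in double quotes and
    # escape embedded characters, so "k: v" pairs still split on ", ".
    text = str(value)
    if ',' not in text and '"' not in text:
        return text
    return '"' + text.replace('\\', '\\\\').replace('"', '\\"') + '"'

params = {"Steps": 20, "Model": "v1-5, pruned"}
infotext = ", ".join(f"{k}: {quote(v)}" for k, v in params.items())
assert infotext == 'Steps: 20, Model: "v1-5, pruned"'
```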
@@ -402,12 +410,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             with devices.autocast():
                 samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)

-            if state.interrupted or state.skipped:
-
-                # if we are interrupted, sample returns just noise
-                # use the image collected previously in sampler loop
-                samples_ddim = shared.state.current_latent
-
             samples_ddim = samples_ddim.to(devices.dtype_vae)
             x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -497,7 +499,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     sampler = None

-    def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
+    def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
         super().__init__(**kwargs)
         self.enable_hr = enable_hr
         self.denoising_strength = denoising_strength
@@ -538,17 +540,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
         self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f

+    def create_dummy_mask(self, x, width=None, height=None):
+        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+            height = height or self.height
+            width = width or self.width
+
+            # The "masked-image" in this case will just be all zeros since the entire image is masked.
+            image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+            image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+            # Add the fake full 1s mask to the first dimension.
+            image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+            image_conditioning = image_conditioning.to(x.dtype)
+
+        else:
+            # Dummy zero conditioning if we're not using inpainting model.
+            # Still takes up a bit of memory, but no encoder call.
+            # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+            image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+        return image_conditioning
+
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
         self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

         if not self.enable_hr:
             x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
             return samples

         x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))

         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
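The `create_dummy_mask` addition above builds a 5-channel conditioning tensor for inpainting checkpoints (1 all-ones mask channel padded onto the VAE's 4 latent channels) and a 1x1 zero placeholder otherwise, where only the batch size matters downstream. The shape logic can be sketched as (hypothetical helper, assuming the usual 4 VAE latent channels):

```python
def dummy_conditioning_shape(batch, latent_h, latent_w, inpainting_model):
    if inpainting_model:
        # 4 VAE latent channels for the (all-zero) masked image,
        # plus 1 mask channel padded onto the channel dimension.
        return (batch, 4 + 1, latent_h, latent_w)
    # Placeholder: only the batch size is consulted downstream,
    # so 1x1 spatial size keeps the memory cost negligible.
    return (batch, 5, 1, 1)

assert dummy_conditioning_shape(2, 64, 64, True) == (2, 5, 64, 64)
assert dummy_conditioning_shape(2, 64, 64, False) == (2, 5, 1, 1)
```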
@@ -585,7 +607,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         x = None
         devices.torch_gc()

-        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))

         return samples
@@ -611,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.inpainting_mask_invert = inpainting_mask_invert
         self.mask = None
         self.nmask = None
+        self.image_conditioning = None

     def init(self, all_prompts, all_seeds, all_subseeds):
         self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
@@ -712,10 +735,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         elif self.inpainting_fill == 3:
             self.init_latent = self.init_latent * self.mask

+        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+            if self.image_mask is not None:
+                conditioning_mask = np.array(self.image_mask.convert("L"))
+                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+                conditioning_mask = torch.round(conditioning_mask)
+            else:
+                conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+            # Create another latent image, this time with a masked version of the original input.
+            conditioning_mask = conditioning_mask.to(image.device)
+            conditioning_image = image * (1.0 - conditioning_mask)
+            conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+            # Create the concatenated conditioning tensor to be fed to `c_concat`
+            conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+            conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+            self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+            self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+        else:
+            self.image_conditioning = torch.zeros(
+                self.init_latent.shape[0], 5, 1, 1,
+                dtype=self.init_latent.dtype,
+                device=self.init_latent.device
+            )
+
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

-        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
+        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

         if self.mask is not None:
             samples = samples * self.nmask + self.init_latent * self.mask
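The img2img branch above converts the user's grayscale mask to [0, 1] floats and then rounds it, because the inpainting model expects a discretized mask of exactly 0.0 or 1.0 per pixel. A pure-Python sketch of that step (the torch version operates on a tensor, but the arithmetic is the same):

```python
def discretize_mask(pixel_values):
    # 8-bit grayscale -> [0, 1] floats, rounded to exactly 0.0 or 1.0,
    # mirroring `torch.round(conditioning_mask)` in the diff.
    return [float(round(v / 255.0)) for v in pixel_values]

assert discretize_mask([0, 100, 200, 255]) == [0.0, 0.0, 1.0, 1.0]
```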
@@ -723,4 +775,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         del x
         devices.torch_gc()

         return samples
@@ -275,7 +275,7 @@ re_attention = re.compile(r"""

 def parse_prompt_attention(text):
     """
-    Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
+    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
     Accepted tokens are:
       (abc) - increases attention to abc by a multiplier of 1.1
       (abc:3.12) - increases attention to abc by a multiplier of 3.12
@@ -96,6 +96,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
 class ScriptRunner:
     def __init__(self):
         self.scripts = []
+        self.titles = []

     def setup_ui(self, is_img2img):
         for script_class, path in scripts_data:
@@ -107,9 +108,10 @@ class ScriptRunner:

             self.scripts.append(script)

-        titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
+        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]

-        dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
+        dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
+        dropdown.save_to_config = True
         inputs = [dropdown]

         for script in self.scripts:
@@ -139,6 +141,15 @@ class ScriptRunner:

             return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]

+        def init_field(title):
+            if title == 'None':
+                return
+            script_index = self.titles.index(title)
+            script = self.scripts[script_index]
+            for i in range(script.args_from, script.args_to):
+                inputs[i].visible = True
+
+        dropdown.init_field = init_field
         dropdown.change(
             fn=select_script,
             inputs=[dropdown],
@@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

+
 def apply_optimizations():
     undo_optimizations()

@@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):

                     remade_tokens = remade_tokens[:last_comma]
                     length = len(remade_tokens)

                     rem = int(math.ceil(length / 75)) * 75 - length
                     remade_tokens += [id_end] * rem + reloc_tokens
                     multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults

                 if embedding is None:
                     remade_tokens.append(token)
                     multipliers.append(weight)
@@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):

         return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count

-
     def process_text_old(self, text):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
@@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):

             token_count = len(remade_tokens)
             remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
-            remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
+            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
             cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

             multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             hijack_fixes.append(fixes)
             batch_multipliers.append(multipliers)
         return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count

     def forward(self, text):
         use_old = opts.use_old_emphasis_implementation
         if use_old:
@@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):

         if len(used_custom_terms) > 0:
             self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

         if use_old:
             self.hijack.fixes = hijack_fixes
             return self.process_tokens(remade_batch_tokens, batch_multipliers)

         z = None
         i = 0
         while max(map(len, remade_batch_tokens)) != 0:
@@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     if fix[0] == i:
                         fixes.append(fix[1])
             self.hijack.fixes.append(fixes)

             tokens = []
             multipliers = []
             for j in range(len(remade_batch_tokens)):
@@ -332,20 +332,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 multipliers.append([1.0] * 75)

             z1 = self.process_tokens(tokens, multipliers)
+            z1 = shared.aesthetic_clip(z1, remade_batch_tokens)
             z = z1 if z is None else torch.cat((z, z1), axis=-2)

             remade_batch_tokens = rem_tokens
             batch_multipliers = rem_multipliers
             i += 1

         return z


     def process_tokens(self, remade_batch_tokens, batch_multipliers):
         if not opts.use_old_emphasis_implementation:
             remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
             batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]

         tokens = torch.asarray(remade_batch_tokens).to(device)
         outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)

@@ -385,8 +385,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
         for fixes, tensor in zip(batch_fixes, inputs_embeds):
             for offset, embedding in fixes:
                 emb = embedding.vec
-                emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
-                tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

             vecs.append(tensor)
331
modules/sd_hijack_inpainting.py
Normal file
331
modules/sd_hijack_inpainting.py
Normal file
|
@ -0,0 +1,331 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from einops import repeat
|
||||||
|
from omegaconf import ListConfig
|
||||||
|
|
||||||
|
import ldm.models.diffusion.ddpm
|
||||||
|
import ldm.models.diffusion.ddim
|
||||||
|
import ldm.models.diffusion.plms
|
||||||
|
|
||||||
|
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
||||||
|
|
||||||
|
# =================================================================================================
|
||||||
|
# Monkey patch DDIMSampler methods from RunwayML repo directly.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# =================================================================================================
@torch.no_grad()
def sample_ddim(self,
                S,
                batch_size,
                shape,
                conditioning=None,
                callback=None,
                normals_sequence=None,
                img_callback=None,
                quantize_x0=False,
                eta=0.,
                mask=None,
                x0=None,
                temperature=1.,
                noise_dropout=0.,
                score_corrector=None,
                corrector_kwargs=None,
                verbose=True,
                x_T=None,
                log_every_t=100,
                unconditional_guidance_scale=1.,
                unconditional_conditioning=None,
                # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
                **kwargs
                ):
    if conditioning is not None:
        if isinstance(conditioning, dict):
            ctmp = conditioning[list(conditioning.keys())[0]]
            while isinstance(ctmp, list):
                ctmp = ctmp[0]
            cbs = ctmp.shape[0]
            if cbs != batch_size:
                print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
        else:
            if conditioning.shape[0] != batch_size:
                print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
    # sampling
    C, H, W = shape
    size = (batch_size, C, H, W)
    print(f'Data shape for DDIM sampling is {size}, eta {eta}')

    samples, intermediates = self.ddim_sampling(conditioning, size,
                                                callback=callback,
                                                img_callback=img_callback,
                                                quantize_denoised=quantize_x0,
                                                mask=mask, x0=x0,
                                                ddim_use_original_steps=False,
                                                noise_dropout=noise_dropout,
                                                temperature=temperature,
                                                score_corrector=score_corrector,
                                                corrector_kwargs=corrector_kwargs,
                                                x_T=x_T,
                                                log_every_t=log_every_t,
                                                unconditional_guidance_scale=unconditional_guidance_scale,
                                                unconditional_conditioning=unconditional_conditioning,
                                                )
    return samples, intermediates

@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None):
    b, *_, device = *x.shape, x.device

    if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
        e_t = self.model.apply_model(x, t, c)
    else:
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        if isinstance(c, dict):
            assert isinstance(unconditional_conditioning, dict)
            c_in = dict()
            for k in c:
                if isinstance(c[k], list):
                    c_in[k] = [
                        torch.cat([unconditional_conditioning[k][i], c[k][i]])
                        for i in range(len(c[k]))
                    ]
                else:
                    c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
        else:
            c_in = torch.cat([unconditional_conditioning, c])
        e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
        e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

    if score_corrector is not None:
        assert self.model.parameterization == "eps"
        e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
    # select parameters corresponding to the currently considered timestep
    a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
    a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
    sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
    sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

    # current prediction for x_0
    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
    if quantize_denoised:
        pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
    noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
    return x_prev, pred_x0

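The line `e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)` above is classifier-free guidance: the unconditional noise prediction is extrapolated toward the conditional one. A minimal standalone sketch of that combination, with plain floats standing in for the prediction tensors (the function name is illustrative, not from the repo):

```python
def apply_cfg(e_t_uncond, e_t_cond, scale):
    # Classifier-free guidance: move from the unconditional prediction
    # toward (and past) the conditional one by `scale`.
    return e_t_uncond + scale * (e_t_cond - e_t_uncond)

# scale == 1.0 reproduces the conditional prediction exactly, which is
# why the patched code skips the doubled batch in that case.
assert apply_cfg(0.25, 0.75, 1.0) == 0.75
# Larger scales overshoot the conditional prediction.
assert apply_cfg(0.25, 0.75, 2.0) == 1.25
```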
# =================================================================================================
# Monkey patch PLMSSampler methods.
# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
# Adapted from:
# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
# =================================================================================================
@torch.no_grad()
def sample_plms(self,
                S,
                batch_size,
                shape,
                conditioning=None,
                callback=None,
                normals_sequence=None,
                img_callback=None,
                quantize_x0=False,
                eta=0.,
                mask=None,
                x0=None,
                temperature=1.,
                noise_dropout=0.,
                score_corrector=None,
                corrector_kwargs=None,
                verbose=True,
                x_T=None,
                log_every_t=100,
                unconditional_guidance_scale=1.,
                unconditional_conditioning=None,
                # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
                **kwargs
                ):
    if conditioning is not None:
        if isinstance(conditioning, dict):
            ctmp = conditioning[list(conditioning.keys())[0]]
            while isinstance(ctmp, list):
                ctmp = ctmp[0]
            cbs = ctmp.shape[0]
            if cbs != batch_size:
                print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
        else:
            if conditioning.shape[0] != batch_size:
                print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
    # sampling
    C, H, W = shape
    size = (batch_size, C, H, W)
    print(f'Data shape for PLMS sampling is {size}')

    samples, intermediates = self.plms_sampling(conditioning, size,
                                                callback=callback,
                                                img_callback=img_callback,
                                                quantize_denoised=quantize_x0,
                                                mask=mask, x0=x0,
                                                ddim_use_original_steps=False,
                                                noise_dropout=noise_dropout,
                                                temperature=temperature,
                                                score_corrector=score_corrector,
                                                corrector_kwargs=corrector_kwargs,
                                                x_T=x_T,
                                                log_every_t=log_every_t,
                                                unconditional_guidance_scale=unconditional_guidance_scale,
                                                unconditional_conditioning=unconditional_conditioning,
                                                )
    return samples, intermediates

@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
    b, *_, device = *x.shape, x.device

    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)

            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    else:
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])

            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    e_t = get_model_output(x, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    elif len(old_eps) >= 3:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

    return x_prev, pred_x0, e_t

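The branch ladder in `p_sample_plms` applies Adams-Bashforth linear multistep formulas: each coefficient set ((3, -1)/2, (23, -16, 5)/12, (55, -59, 37, -9)/24) combines the current noise prediction with previous ones, and each set sums to 1 so a constant prediction passes through unchanged. A standalone sketch of that combination step (pure floats instead of tensors; `plms_combine` is an illustrative name):

```python
def plms_combine(e_t, old_eps):
    # Mirror of the branch ladder above: pick the highest-order
    # Adams-Bashforth formula the available history allows.
    if len(old_eps) == 0:
        return e_t  # the real code uses a Heun-style half step here instead
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

# Each coefficient set sums to 1, so a constant prediction is a fixed point.
for hist in ([], [1.0], [1.0, 1.0], [1.0, 1.0, 1.0]):
    assert plms_combine(1.0, hist) == 1.0
```

This sum-to-one property is what makes the multistep estimate an unbiased extrapolation of the noise prediction rather than a rescaling of it.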
# =================================================================================================
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
# =================================================================================================

@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
    if null_label is not None:
        xc = null_label
        if isinstance(xc, ListConfig):
            xc = list(xc)
        if isinstance(xc, dict) or isinstance(xc, list):
            c = self.get_learned_conditioning(xc)
        else:
            if hasattr(xc, "to"):
                xc = xc.to(self.device)
            c = self.get_learned_conditioning(xc)
    else:
        # todo: get null label from cond_stage_model
        raise NotImplementedError()
    c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
    return c

class LatentInpaintDiffusion(LatentDiffusion):
    def __init__(
        self,
        concat_keys=("mask", "masked_image"),
        masked_image_key="masked_image",
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.masked_image_key = masked_image_key
        assert self.masked_image_key in concat_keys
        self.concat_keys = concat_keys


def should_hijack_inpainting(checkpoint_info):
    return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")


def do_inpainting_hijack():
    ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
    ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion

    ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
    ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim

    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
    ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
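`should_hijack_inpainting` keys off the checkpoint filename: the hijack applies only when the weights look like an inpainting checkpoint but the selected config is not already an inpainting config. A standalone sketch of that predicate; the `FakeCheckpointInfo` namedtuple is a stand-in for the real `CheckpointInfo` in `modules.sd_models`, and the filenames are illustrative:

```python
from collections import namedtuple

# Stand-in for the real CheckpointInfo namedtuple in modules.sd_models.
FakeCheckpointInfo = namedtuple("FakeCheckpointInfo", ["filename", "config"])

def should_hijack_inpainting(checkpoint_info):
    # Same predicate as above: inpainting weights paired with a
    # config that is not already an inpainting config.
    return (str(checkpoint_info.filename).endswith("inpainting.ckpt")
            and not checkpoint_info.config.endswith("inpainting.yaml"))

# Inpainting weights + regular config: hijack.
assert should_hijack_inpainting(FakeCheckpointInfo("sd-v1-5-inpainting.ckpt", "v1-inference.yaml"))
# Inpainting weights + inpainting config: already handled, no hijack.
assert not should_hijack_inpainting(FakeCheckpointInfo("sd-v1-5-inpainting.ckpt", "v1-inpainting.yaml"))
# Regular weights: no hijack.
assert not should_hijack_inpainting(FakeCheckpointInfo("sd-v1-5.ckpt", "v1-inference.yaml"))
```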
@@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v):
     mem_free_torch = mem_reserved - mem_active
     mem_free_total = mem_free_cuda + mem_free_torch
     # Divide factor of safety as there's copying and fragmentation
-    return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

 def einsum_op(q, k, v):
     if q.device.type == 'cuda':
@@ -296,10 +296,16 @@ def xformers_attnblock_forward(self, x):
     try:
         h_ = x
         h_ = self.norm(h_)
-        q1 = self.q(h_).contiguous()
-        k1 = self.k(h_).contiguous()
-        v = self.v(h_).contiguous()
-        out = xformers.ops.memory_efficient_attention(q1, k1, v)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+        b, c, h, w = q.shape
+        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        out = xformers.ops.memory_efficient_attention(q, k, v)
+        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
         out = self.proj_out(out)
         return x + out
     except NotImplementedError:
@@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config

 from modules import shared, modelloader, devices
 from modules.paths import models_path
+from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting

 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -20,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict()
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.

-    from transformers import logging
+    from transformers import logging, CLIPModel

     logging.set_verbosity_error()
 except Exception:
@@ -122,9 +123,34 @@ def select_checkpoint():
     return checkpoint_info


+checkpoint_dict_replacements = {
+    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
+    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
+    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
+}
+
+
+def transform_checkpoint_dict_key(k):
+    for text, replacement in checkpoint_dict_replacements.items():
+        if k.startswith(text):
+            k = replacement + k[len(text):]
+
+    return k
+
+
 def get_state_dict_from_checkpoint(pl_sd):
     if "state_dict" in pl_sd:
-        return pl_sd["state_dict"]
+        pl_sd = pl_sd["state_dict"]
+
+    sd = {}
+    for k, v in pl_sd.items():
+        new_key = transform_checkpoint_dict_key(k)
+
+        if new_key is not None:
+            sd[new_key] = v
+
+    pl_sd.clear()
+    pl_sd.update(sd)
+
     return pl_sd
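`transform_checkpoint_dict_key` renames legacy CLIP text-encoder keys by prefix so that older checkpoints load into the newer `text_model`-nested transformers layout. A minimal standalone sketch of that prefix rewrite, with a one-entry mapping table mirroring the ones in the diff (toy keys for illustration):

```python
# One entry from the replacement table in the diff, for illustration.
replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
}

def transform_key(k):
    # Rewrite the first matching prefix; leave all other keys untouched.
    for text, replacement in replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text):]
    return k

# Legacy CLIP key is remapped into the text_model namespace...
assert transform_key('cond_stage_model.transformer.embeddings.position_ids') == \
    'cond_stage_model.transformer.text_model.embeddings.position_ids'
# ...while UNet keys pass through unchanged.
assert transform_key('model.diffusion_model.input_blocks.0.weight') == \
    'model.diffusion_model.input_blocks.0.weight'
```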
@@ -141,7 +167,7 @@ def load_model_weights(model, checkpoint_info):
         print(f"Global Step: {pl_sd['global_step']}")

     sd = get_state_dict_from_checkpoint(pl_sd)
-    model.load_state_dict(sd, strict=False)
+    missing, extra = model.load_state_dict(sd, strict=False)

     if shared.cmd_opts.opt_channelslast:
         model.to(memory_format=torch.channels_last)
@@ -178,14 +204,26 @@ def load_model_weights(model, checkpoint_info):
     model.sd_checkpoint_info = checkpoint_info


-def load_model():
+def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
-    checkpoint_info = select_checkpoint()
+    checkpoint_info = checkpoint_info or select_checkpoint()

     if checkpoint_info.config != shared.cmd_opts.config:
         print(f"Loading config from: {checkpoint_info.config}")

     sd_config = OmegaConf.load(checkpoint_info.config)
+
+    if should_hijack_inpainting(checkpoint_info):
+        # Hardcoded config for now...
+        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+        sd_config.model.params.use_ema = False
+        sd_config.model.params.conditioning_key = "hybrid"
+        sd_config.model.params.unet_config.params.in_channels = 9
+
+        # Create a "fake" config with a different name so that we know to unload it when switching models.
+        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+
+        do_inpainting_hijack()
+
     sd_model = instantiate_from_config(sd_config.model)
     load_model_weights(sd_model, checkpoint_info)
@@ -209,9 +247,9 @@ def reload_model_weights(sd_model, info=None):
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return

-    if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+    if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
         checkpoints_loaded.clear()
-        shared.sd_model = load_model()
+        shared.sd_model = load_model(checkpoint_info)
         return shared.sd_model

     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -98,25 +98,8 @@ def store_latent(decoded):
         shared.state.current_image = sample_to_image(decoded)


-def extended_tdqm(sequence, *args, desc=None, **kwargs):
-    state.sampling_steps = len(sequence)
-    state.sampling_step = 0
-
-    seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
-
-    for x in seq:
-        if state.interrupted or state.skipped:
-            break
-
-        yield x
-
-        state.sampling_step += 1
-        shared.total_tqdm.update()
-
-
-ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
-ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
+class InterruptedException(BaseException):
+    pass


 class VanillaStableDiffusionSampler:
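Deriving `InterruptedException` from `BaseException` rather than `Exception` appears deliberate: a blanket `except Exception:` handler anywhere in the sampling stack will not swallow the interrupt, so it reliably reaches the `except InterruptedException:` catch in `launch_sampling`. A small standalone sketch of that design choice (function names are illustrative):

```python
class InterruptedException(BaseException):
    pass

def step():
    # A user interrupt raised from deep inside the sampling loop.
    raise InterruptedException

def run():
    try:
        step()
    except Exception:
        # Generic error handling does NOT swallow the interrupt,
        # because InterruptedException is not an Exception subclass.
        return "handled error"

caught = None
try:
    run()
except InterruptedException:
    caught = "interrupted"

assert caught == "interrupted"
```

Had the class derived from `Exception`, `run()` would have returned "handled error" and the interrupt would have been lost.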
@@ -128,14 +111,40 @@ class VanillaStableDiffusionSampler:
         self.init_latent = None
         self.sampler_noises = None
         self.step = 0
+        self.stop_at = None
         self.eta = None
         self.default_eta = 0.0
         self.config = None
+        self.last_latent = None
+
+        self.conditioning_key = sd_model.model.conditioning_key

     def number_of_needed_noises(self, p):
         return 0

+    def launch_sampling(self, steps, func):
+        state.sampling_steps = steps
+        state.sampling_step = 0
+
+        try:
+            return func()
+        except InterruptedException:
+            return self.last_latent
+
     def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+        if state.interrupted or state.skipped:
+            raise InterruptedException
+
+        if self.stop_at is not None and self.step > self.stop_at:
+            raise InterruptedException
+
+        # Have to unwrap the inpainting conditioning here to perform pre-processing
+        image_conditioning = None
+        if isinstance(cond, dict):
+            image_conditioning = cond["c_concat"][0]
+            cond = cond["c_crossattn"][0]
+            unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
+
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
@@ -156,14 +165,25 @@ class VanillaStableDiffusionSampler:
             img_orig = self.sampler.model.q_sample(self.init_latent, ts)
             x_dec = img_orig * self.mask + self.nmask * x_dec

+        # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+        # Note that they need to be lists because it just concatenates them later.
+        if image_conditioning is not None:
+            cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
         res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)

         if self.mask is not None:
-            store_latent(self.init_latent * self.mask + self.nmask * res[1])
+            self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
         else:
-            store_latent(res[1])
+            self.last_latent = res[1]
+
+        store_latent(self.last_latent)

         self.step += 1
+        state.sampling_step = self.step
+        shared.total_tqdm.update()
+
         return res

     def initialize(self, p):
@@ -176,7 +196,7 @@ class VanillaStableDiffusionSampler:
         self.mask = p.mask if hasattr(p, 'mask') else None
         self.nmask = p.nmask if hasattr(p, 'nmask') else None

-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         steps, t_enc = setup_img2img_steps(p, steps)

         self.initialize(p)
@@ -190,25 +210,38 @@ class VanillaStableDiffusionSampler:
         x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)

         self.init_latent = x
+        self.last_latent = x
         self.step = 0

-        samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
+        # Wrap the conditioning models with additional image conditioning for inpainting model
+        if image_conditioning is not None:
+            conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+        samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+
         return samples

-    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         self.initialize(p)

         self.init_latent = None
+        self.last_latent = x
         self.step = 0

         steps = steps or p.steps

+        # Wrap the conditioning models with additional image conditioning for inpainting model
+        if image_conditioning is not None:
+            conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
         # existing code fails with certain step counts, like 9
         try:
-            samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
+            samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
         except Exception:
-            samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
+            samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])

         return samples_ddim

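The wrapping pattern repeated above packs the text conditioning and the inpainting image conditioning into the `c_crossattn`/`c_concat` dict that hybrid conditioning expects; both values sit in lists because the model concatenates list entries later. A standalone sketch of the wrapping step, with plain strings standing in for the conditioning tensors (`wrap_cond` is an illustrative name):

```python
def wrap_cond(conditioning, image_conditioning=None):
    # Mirror of the diff: only inpainting models carry image conditioning;
    # otherwise the plain cross-attention conditioning passes through.
    if image_conditioning is None:
        return conditioning
    return {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}

# No image conditioning: unchanged.
assert wrap_cond("text_cond") == "text_cond"
# With image conditioning: both entries wrapped in single-element lists.
wrapped = wrap_cond("text_cond", "masked_image_latent")
assert wrapped == {"c_concat": ["masked_image_latent"], "c_crossattn": ["text_cond"]}
```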
@@ -222,7 +255,10 @@ class CFGDenoiser(torch.nn.Module):
         self.init_latent = None
         self.step = 0

-    def forward(self, x, sigma, uncond, cond, cond_scale):
+    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+        if state.interrupted or state.skipped:
+            raise InterruptedException
+
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
@ -230,28 +266,29 @@ class CFGDenoiser(torch.nn.Module):
|
||||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||||
|
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||||
|
+        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
         sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])

         if tensor.shape[1] == uncond.shape[1]:
             cond_in = torch.cat([tensor, uncond])

             if shared.batch_cond_uncond:
-                x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+                x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
             else:
                 x_out = torch.zeros_like(x_in)
                 for batch_offset in range(0, x_out.shape[0], batch_size):
                     a = batch_offset
                     b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
         else:
             x_out = torch.zeros_like(x_in)
             batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
             for batch_offset in range(0, tensor.shape[0], batch_size):
                 a = batch_offset
                 b = min(a + batch_size, tensor.shape[0])
-                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
+                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})

-            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
+            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})

         denoised_uncond = x_out[-uncond.shape[0]:]
         denoised = torch.clone(denoised_uncond)
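The denoiser builds one large batch in which each image's conditionings are repeated and the unconditional entries always sit at the tail, which is why it can slice them with `x_out[-uncond.shape[0]:]`. A minimal stdlib sketch of that layout (hypothetical helper; plain lists stand in for tensors, `extend` for `torch.cat`/`torch.stack`):

```python
# Illustrative sketch of the batch layout CFGDenoiser builds.
def build_batch(items, repeats, uncond):
    out = []
    for i, n in enumerate(repeats):
        out.extend([items[i]] * n)  # repeat entry i once per cond for image i
    out.extend(uncond)              # unconditional entries occupy the tail
    return out

batch = build_batch(["x0", "x1"], [2, 1], ["u0", "u1"])
# the last len(uncond) entries are the unconditional ones, so the tail slice
# batch[-len(uncond):] recovers them regardless of the per-image repeat counts
```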
@@ -268,25 +305,6 @@ class CFGDenoiser(torch.nn.Module):
         return denoised

-def extended_trange(sampler, count, *args, **kwargs):
-    state.sampling_steps = count
-    state.sampling_step = 0
-
-    seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
-
-    for x in seq:
-        if state.interrupted or state.skipped:
-            break
-
-        if sampler.stop_at is not None and x > sampler.stop_at:
-            break
-
-        yield x
-
-        state.sampling_step += 1
-        shared.total_tqdm.update()
-
 class TorchHijack:
     def __init__(self, kdiff_sampler):
         self.kdiff_sampler = kdiff_sampler

@@ -314,9 +332,30 @@ class KDiffusionSampler:
         self.eta = None
         self.default_eta = 1.0
         self.config = None
+        self.last_latent = None
+
+        self.conditioning_key = sd_model.model.conditioning_key

     def callback_state(self, d):
-        store_latent(d["denoised"])
+        step = d['i']
+        latent = d["denoised"]
+        store_latent(latent)
+        self.last_latent = latent
+
+        if self.stop_at is not None and step > self.stop_at:
+            raise InterruptedException
+
+        state.sampling_step = step
+        shared.total_tqdm.update()
+
+    def launch_sampling(self, steps, func):
+        state.sampling_steps = steps
+        state.sampling_step = 0
+
+        try:
+            return func()
+        except InterruptedException:
+            return self.last_latent

     def number_of_needed_noises(self, p):
         return p.steps
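The `callback_state`/`launch_sampling` pair replaces the removed `extended_trange` generator: the per-step callback raises to abort, and the wrapper catches the exception and hands back the last stored latent as a partial result. A self-contained sketch of the pattern (class and method names mirror the diff, but this is an illustration, not the actual module):

```python
class InterruptedException(Exception):
    pass


class MiniSampler:
    """Illustrative stand-in for KDiffusionSampler's interrupt handling."""

    def __init__(self, stop_at=None):
        self.stop_at = stop_at
        self.last_latent = None

    def callback_state(self, step, latent):
        self.last_latent = latent  # remember the newest denoised latent
        if self.stop_at is not None and step > self.stop_at:
            raise InterruptedException  # unwind out of the sampling loop

    def launch_sampling(self, func):
        try:
            return func()
        except InterruptedException:
            return self.last_latent  # partial result instead of a crash

    def run(self, steps):
        def loop():
            latent = None
            for step in range(steps):
                latent = f"latent@{step}"  # stands in for one denoising step
                self.callback_state(step, latent)
            return latent
        return self.launch_sampling(loop)
```

Because the exception propagates through the third-party sampler's own loop, no cooperation from k-diffusion is needed to stop early.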
@@ -339,9 +378,6 @@ class KDiffusionSampler:
         self.sampler_noise_index = 0
         self.eta = p.eta or opts.eta_ancestral

-        if hasattr(k_diffusion.sampling, 'trange'):
-            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
-
         if self.sampler_noises is not None:
             k_diffusion.sampling.torch = TorchHijack(self)
@@ -355,7 +391,7 @@ class KDiffusionSampler:
         return extra_params_kwargs

-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         steps, t_enc = setup_img2img_steps(p, steps)

         if p.sampler_noise_scheduler_override:

@@ -382,11 +418,18 @@ class KDiffusionSampler:
         extra_params_kwargs['sigmas'] = sigma_sched

         self.model_wrap_cfg.init_latent = x
+        self.last_latent = x

-        return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+            'cond': conditioning,
+            'image_cond': image_conditioning,
+            'uncond': unconditional_conditioning,
+            'cond_scale': p.cfg_scale
+        }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+        return samples
-    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         steps = steps or p.steps

         if p.sampler_noise_scheduler_override:

@@ -406,6 +449,14 @@ class KDiffusionSampler:
             extra_params_kwargs['n'] = steps
         else:
             extra_params_kwargs['sigmas'] = sigmas
-        samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+
+        self.last_latent = x
+        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+            'cond': conditioning,
+            'image_cond': image_conditioning,
+            'uncond': unconditional_conditioning,
+            'cond_scale': p.cfg_scale
+        }, disable=False, callback=self.callback_state, **extra_params_kwargs))

         return samples
modules/shared.py

@@ -3,6 +3,7 @@ import datetime
 import json
 import os
 import sys
+from collections import OrderedDict

 import gradio as gr
 import tqdm
@ -30,6 +31,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
|
||||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||||
|
parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
|
||||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||||
|
@@ -70,12 +72,14 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload
 parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
 parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
 parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
+parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
 parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
 parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
 parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
 parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
 parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
+parser.add_argument("--api", action='store_true', help="launch the API together with the webui")
+parser.add_argument("--nowebui", action='store_true', help="launch only the API, without the webui")

 cmd_opts = parser.parse_args()
 restricted_opts = [
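The new `--api`, `--nowebui`, and `--theme` switches follow the parser's usual `action='store_true'` / typed-default pattern. A quick standalone sketch of how such flags parse (illustrative parser, not the webui's actual one):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--api", action='store_true', help="launch the API together with the webui")
parser.add_argument("--nowebui", action='store_true', help="launch only the API, without the webui")
parser.add_argument("--theme", type=str, default=None, help="light or dark theme")

# store_true flags default to False when absent and become True when present
opts = parser.parse_args(["--api", "--theme", "dark"])
```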
@ -104,6 +108,21 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||||
loaded_hypernetwork = None
|
loaded_hypernetwork = None
|
||||||
|
|
||||||
|
|
||||||
|
os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True)
|
||||||
|
aesthetic_embeddings = {}
|
||||||
|
|
||||||
|
|
||||||
|
def update_aesthetic_embeddings():
|
||||||
|
global aesthetic_embeddings
|
||||||
|
aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
|
||||||
|
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
|
||||||
|
aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
update_aesthetic_embeddings()
|
||||||
|
|
||||||
|
|
||||||
def reload_hypernetworks():
|
def reload_hypernetworks():
|
||||||
global hypernetworks
|
global hypernetworks
|
||||||
|
|
||||||
|
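`update_aesthetic_embeddings` maps each `*.pt` file to its path and prepends a `"None"` entry so the UI dropdown always offers a no-op choice. A sketch of the same scan (note: this example uses `os.path.splitext` rather than the diff's `f.replace(".pt", "")`, which would also mangle a `.pt` occurring mid-name; the helper name is invented for the example):

```python
import os
import tempfile
from collections import OrderedDict


def list_embeddings(dirpath):
    # map stem -> full path for every .pt file in the directory
    found = {os.path.splitext(f)[0]: os.path.join(dirpath, f)
             for f in os.listdir(dirpath) if f.endswith(".pt")}
    # "None" comes first so the UI has an explicit "no embedding" entry
    return OrderedDict(**{"None": None}, **found)


with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "style.pt"), "w").close()
    open(os.path.join(d, "notes.txt"), "w").close()  # ignored: wrong extension
    emb = list_embeddings(d)
```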
@@ -135,7 +154,7 @@ class State:
         self.job_no += 1
         self.sampling_step = 0
         self.current_image_sampling_step = 0

     def get_job_timestamp(self):
         return datetime.datetime.now().strftime("%Y%m%d%H%M%S")  # shouldn't this return job_timestamp?
@@ -293,6 +312,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
     "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
+    "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
     "font": OptionInfo("", "Font for image grids that have text"),
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
@@ -384,6 +404,11 @@ sd_upscalers = []

 sd_model = None

+clip_model = None
+
+from modules.aesthetic_clip import AestheticCLIP
+aesthetic_clip = AestheticCLIP()
+
 progress_print_out = sys.stdout
modules/styles.py

@@ -45,7 +45,7 @@ class StyleDatabase:
         if not os.path.exists(path):
             return

-        with open(path, "r", encoding="utf8", newline='') as file:
+        with open(path, "r", encoding="utf-8-sig", newline='') as file:
             reader = csv.DictReader(file)
             for row in reader:
                 # Support loading old CSV format with "name, text"-columns

@@ -79,7 +79,7 @@ class StyleDatabase:
     def save_styles(self, path: str) -> None:
         # Write to temporary file first, so we don't nuke the file if something goes wrong
         fd, temp_path = tempfile.mkstemp(".csv")
-        with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
+        with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
             # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
             # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
             writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
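The switch from `utf8` to `utf-8-sig` matters because spreadsheet apps commonly write a UTF-8 byte-order mark at the start of a CSV; with plain `utf8` the BOM gets glued onto the first header name. A small demonstration:

```python
import csv
import io

# BOM-prefixed CSV, as written by Excel and similar tools
raw = b'\xef\xbb\xbfname,prompt\r\nportrait,a photo of a person\r\n'

# utf-8-sig strips the BOM, so headers parse cleanly
rows = list(csv.DictReader(io.TextIOWrapper(io.BytesIO(raw), encoding='utf-8-sig')))

# plain utf8 keeps the BOM, corrupting the first field name
bad = list(csv.DictReader(io.TextIOWrapper(io.BytesIO(raw), encoding='utf8')))
```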
modules/textual_inversion/dataset.py

@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
         self.dataset.append(entry)

-        assert len(self.dataset) > 1, "No images have been found in the dataset."
+        assert len(self.dataset) > 0, "No images have been found in the dataset."
         self.length = len(self.dataset) * repeats // batch_size

         self.initial_indexes = np.arange(len(self.dataset))

@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
         self.shuffle()

     def shuffle(self):
-        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]

     def create_text(self, filename_text):
         text = random.choice(self.lines)
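`PersonalizedBase` derives its epoch length from `len(dataset) * repeats // batch_size` and walks a freshly shuffled index order. The same bookkeeping, sketched with the stdlib (illustrative only; the real class shuffles a NumPy index array via `torch.randperm(...).numpy()`):

```python
import random


def epoch_plan(num_images, repeats, batch_size, seed=0):
    # same length formula as the dataset class
    length = num_images * repeats // batch_size
    # shuffled visiting order; stands in for torch.randperm(...).numpy()
    indexes = list(range(num_images))
    random.Random(seed).shuffle(indexes)
    return length, indexes
```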
modules/textual_inversion/image_embedding.py

@@ -5,6 +5,7 @@ import zlib
 from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
 from fonts.ttf import Roboto
 import torch
+from modules.shared import opts


 class EmbeddingEncoder(json.JSONEncoder):

@@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
     from math import cos

     image = srcimage.copy()
+    fontsize = 32
     if textfont is None:
         try:
             textfont = ImageFont.truetype(opts.font or Roboto, fontsize)

@@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
     image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))

     draw = ImageDraw.Draw(image)
-    fontsize = 32
     font = ImageFont.truetype(textfont, fontsize)
     padding = 10
modules/textual_inversion/preprocess.py

@@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru:
     import modules.deepbooru as deepbooru


-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
     try:
         if process_caption:
             shared.interrogator.load()

@@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
             db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
             deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)

-        preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+        preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru)

     finally:

@@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_

-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
     width = process_width
     height = process_height
     src = os.path.abspath(process_src)

@@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
     shared.state.textinfo = "Preprocessing..."
     shared.state.job_count = len(files)

-    def save_pic_with_caption(image, index):
+    def save_pic_with_caption(image, index, existing_caption=None):
         caption = ""

         if process_caption:
@@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
         basename = f"{index:05}-{subindex[0]}-{filename_part}"
         image.save(os.path.join(dst, f"{basename}.png"))

+        if preprocess_txt_action == 'prepend' and existing_caption:
+            caption = existing_caption + ' ' + caption
+        elif preprocess_txt_action == 'append' and existing_caption:
+            caption = caption + ' ' + existing_caption
+        elif preprocess_txt_action == 'copy' and existing_caption:
+            caption = existing_caption
+
+        caption = caption.strip()
+
         if len(caption) > 0:
             with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
                 file.write(caption)

         subindex[0] += 1

-    def save_pic(image, index):
-        save_pic_with_caption(image, index)
+    def save_pic(image, index, existing_caption=None):
+        save_pic_with_caption(image, index, existing_caption=existing_caption)

         if process_flip:
-            save_pic_with_caption(ImageOps.mirror(image), index)
+            save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)

     for index, imagefile in enumerate(tqdm.tqdm(files)):
         subindex = [0]

@@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
         except Exception:
             continue

+        existing_caption = None
+
+        try:
+            existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
+        except Exception as e:
+            print(e)
+
         if shared.state.interrupted:
             break
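The prepend/append/copy branches reduce to a small pure function; extracting it (a hypothetical refactor, not part of the diff) makes the behavior easy to check in isolation:

```python
def merge_caption(generated, existing, action):
    """Combine a freshly generated caption with a pre-existing .txt caption."""
    if action == 'prepend' and existing:
        caption = existing + ' ' + generated
    elif action == 'append' and existing:
        caption = generated + ' ' + existing
    elif action == 'copy' and existing:
        caption = existing
    else:
        caption = generated  # no existing caption, or action == 'ignore'
    return caption.strip()
```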
@@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
             img = img.resize((width, height * img.height // img.width))

             top = img.crop((0, 0, width, height))
-            save_pic(top, index)
+            save_pic(top, index, existing_caption=existing_caption)

             bot = img.crop((0, img.height - height, width, img.height))
-            save_pic(bot, index)
+            save_pic(bot, index, existing_caption=existing_caption)
         elif process_split and is_wide:
             img = img.resize((width * img.width // img.height, height))

             left = img.crop((0, 0, width, height))
-            save_pic(left, index)
+            save_pic(left, index, existing_caption=existing_caption)

             right = img.crop((img.width - width, 0, img.width, height))
-            save_pic(right, index)
+            save_pic(right, index, existing_caption=existing_caption)
         else:
             img = images.resize_image(1, img, width, height)
-            save_pic(img, index)
+            save_pic(img, index, existing_caption=existing_caption)

         shared.state.nextjob()
modules/textual_inversion/textual_inversion.py

@@ -153,7 +153,7 @@ class EmbeddingDatabase:
         return None, None


-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings

@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
         vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]

     fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
-    assert not os.path.exists(fn), f"file {fn} already exists"
+    if not overwrite_old:
+        assert not os.path.exists(fn), f"file {fn} already exists"

     embedding = Embedding(vec, name)
     embedding.step = 0

@@ -275,6 +276,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         loss.backward()
         optimizer.step()
+
         epoch_num = embedding.step // len(ds)
         epoch_step = embedding.step - (epoch_num * len(ds)) + 1
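The `overwrite_old` change relaxes the existence check only when the caller explicitly opts in. The guard in isolation (helper name invented for this example; note that `assert`-based guards are skipped under `python -O`):

```python
import os
import tempfile


def check_overwrite(fn, overwrite_old):
    """Raise unless the target file is absent or overwriting was requested."""
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"


with tempfile.TemporaryDirectory() as d:
    fn = os.path.join(d, "emb.pt")
    check_overwrite(fn, False)           # fine: nothing there yet
    open(fn, "w").close()
    check_overwrite(fn, True)            # fine: overwrite requested
    try:
        check_overwrite(fn, False)       # rejected: file exists
        blocked = False
    except AssertionError:
        blocked = True
```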
modules/textual_inversion/ui.py

@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
 from modules import sd_hijack, shared


-def create_embedding(name, initialization_text, nvpt):
-    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)

     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
modules/txt2img.py

@@ -1,12 +1,13 @@
 import modules.scripts
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
+    StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, cmd_opts
 import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html


-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
     p = StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,

@@ -35,6 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         firstphase_height=firstphase_height if enable_hr else None,
     )

+    shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
+
     if cmd_opts.enable_console_prompts:
         print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

@@ -53,4 +56,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         processed.images = []

     return processed.images, generation_info_js, plaintext_to_html(processed.info)
205 modules/ui.py
@@ -12,7 +12,7 @@ import time
 import traceback
 import platform
 import subprocess as sp
-from functools import reduce
+from functools import partial, reduce

 import numpy as np
 import torch

@@ -25,7 +25,9 @@ import gradio.routes

 from modules import sd_hijack, sd_models, localization
 from modules.paths import script_path
-from modules.shared import opts, cmd_opts, restricted_opts
+
+from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings

 if cmd_opts.deepdanbooru:
     from modules.deepbooru import get_deepbooru_tags
 import modules.shared as shared

@@ -41,8 +43,11 @@ from modules import prompt_parser
 from modules.images import save_image
 import modules.textual_inversion.ui
 import modules.hypernetworks.ui
+
+import modules.aesthetic_clip as aesthetic_clip
 import modules.images_history as img_his


 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
 mimetypes.init()
 mimetypes.add_type('application/javascript', '.js')
@@ -261,6 +266,24 @@ def wrap_gradio_call(func, extra_outputs=None):
     return f
 
 
+def calc_time_left(progress, threshold, label, force_display):
+    if progress == 0:
+        return ""
+    else:
+        time_since_start = time.time() - shared.state.time_start
+        eta = (time_since_start/progress)
+        eta_relative = eta-time_since_start
+        if (eta_relative > threshold and progress > 0.02) or force_display:
+            if eta_relative > 3600:
+                return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+            elif eta_relative > 60:
+                return label + time.strftime('%M:%S', time.gmtime(eta_relative))
+            else:
+                return label + time.strftime('%Ss', time.gmtime(eta_relative))
+        else:
+            return ""
+
+
 def check_progress_call(id_part):
     if shared.state.job_count == 0:
         return "", gr_show(False), gr_show(False), gr_show(False)
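The `calc_time_left` helper added above estimates remaining time by linear extrapolation: the projected total runtime is elapsed time divided by the fraction complete, and the ETA is that total minus the elapsed time. A standalone sketch of the same arithmetic; passing `time_start` as an explicit parameter in place of the webui's `shared.state` global is an adaptation made for this example:

```python
import time

def calc_time_left(progress, time_start, threshold=1.0, label=" ETA: ", force_display=False):
    """Estimate remaining time from the fraction of work done (diff's logic)."""
    if progress == 0:
        return ""
    time_since_start = time.time() - time_start
    eta = time_since_start / progress       # projected total runtime
    eta_relative = eta - time_since_start   # projected time remaining
    # Only show an ETA once it is meaningful (enough progress, enough time left),
    # unless the caller forces display.
    if (eta_relative > threshold and progress > 0.02) or force_display:
        if eta_relative > 3600:
            return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
        elif eta_relative > 60:
            return label + time.strftime('%M:%S', time.gmtime(eta_relative))
        return label + time.strftime('%Ss', time.gmtime(eta_relative))
    return ""
```

Note the `progress > 0.02` guard: early in a run the extrapolation is dominated by startup cost, so the estimate is suppressed until at least 2% of the work is done.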
@@ -272,11 +295,15 @@ def check_progress_call(id_part):
     if shared.state.sampling_steps > 0:
         progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
 
+    time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
+    if time_left != "":
+        shared.state.time_left_force_display = True
+
     progress = min(progress, 1)
 
     progressbar = ""
     if opts.show_progressbar:
-        progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
+        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>"""
 
     image = gr_show(False)
     preview_visibility = gr_show(False)
@@ -308,6 +335,8 @@ def check_progress_call_initial(id_part):
     shared.state.current_latent = None
     shared.state.current_image = None
     shared.state.textinfo = None
+    shared.state.time_start = time.time()
+    shared.state.time_left_force_display = False
 
     return check_progress_call(id_part)
 
@@ -458,14 +487,14 @@ def create_toprow(is_img2img):
             with gr.Row():
                 with gr.Column(scale=80):
                     with gr.Row():
                         prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
                             placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
                         )
 
             with gr.Row():
                 with gr.Column(scale=80):
                     with gr.Row():
                         negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
                             placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
                         )
@@ -542,6 +571,10 @@ def apply_setting(key, value):
     if value is None:
         return gr.update()
 
+    # dont allow model to be swapped when model hash exists in prompt
+    if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
+        return gr.update()
+
     if key == "sd_model_checkpoint":
         ckpt_info = sd_models.get_closet_checkpoint_match(value)
@@ -564,27 +597,29 @@ def apply_setting(key, value):
     return value
 
 
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+    def refresh():
+        refresh_method()
+        args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+        for k, v in args.items():
+            setattr(refresh_component, k, v)
+
+        return gr.update(**(args or {}))
+
+    refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
+    refresh_button.click(
+        fn=refresh,
+        inputs=[],
+        outputs=[refresh_component]
+    )
+    return refresh_button
+
+
 def create_ui(wrap_gradio_gpu_call):
     import modules.img2img
     import modules.txt2img
 
-    def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
-        def refresh():
-            refresh_method()
-            args = refreshed_args() if callable(refreshed_args) else refreshed_args
-
-            for k, v in args.items():
-                setattr(refresh_component, k, v)
-
-            return gr.update(**(args or {}))
-
-        refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
-        refresh_button.click(
-            fn = refresh,
-            inputs = [],
-            outputs = [refresh_component]
-        )
-        return refresh_button
-
     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
         txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
@@ -627,6 +662,8 @@ def create_ui(wrap_gradio_gpu_call):
 
         seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
 
+        aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui()
+
         with gr.Group():
             custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
 
@@ -681,7 +718,16 @@ def create_ui(wrap_gradio_gpu_call):
                     denoising_strength,
                     firstphase_width,
                     firstphase_height,
+                    aesthetic_lr,
+                    aesthetic_weight,
+                    aesthetic_steps,
+                    aesthetic_imgs,
+                    aesthetic_slerp,
+                    aesthetic_imgs_text,
+                    aesthetic_slerp_angle,
+                    aesthetic_text_negative
                 ] + custom_inputs,
 
                 outputs=[
                     txt2img_gallery,
                     generation_info,
@@ -758,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call):
         (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
         (firstphase_width, "First pass size-1"),
         (firstphase_height, "First pass size-2"),
+        (aesthetic_lr, "Aesthetic LR"),
+        (aesthetic_weight, "Aesthetic weight"),
+        (aesthetic_steps, "Aesthetic steps"),
+        (aesthetic_imgs, "Aesthetic embedding"),
+        (aesthetic_slerp, "Aesthetic slerp"),
+        (aesthetic_imgs_text, "Aesthetic text"),
+        (aesthetic_text_negative, "Aesthetic text negative"),
+        (aesthetic_slerp_angle, "Aesthetic slerp angle"),
     ]
 
     txt2img_preview_params = [
@@ -842,6 +896,8 @@ def create_ui(wrap_gradio_gpu_call):
 
         seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
 
+        aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui()
+
         with gr.Group():
             custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
 
@@ -932,6 +988,14 @@ def create_ui(wrap_gradio_gpu_call):
                     inpainting_mask_invert,
                     img2img_batch_input_dir,
                     img2img_batch_output_dir,
+                    aesthetic_lr_im,
+                    aesthetic_weight_im,
+                    aesthetic_steps_im,
+                    aesthetic_imgs_im,
+                    aesthetic_slerp_im,
+                    aesthetic_imgs_text_im,
+                    aesthetic_slerp_angle_im,
+                    aesthetic_text_negative_im,
                 ] + custom_inputs,
                 outputs=[
                     img2img_gallery,
@@ -1023,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call):
         (seed_resize_from_w, "Seed resize from-1"),
         (seed_resize_from_h, "Seed resize from-2"),
         (denoising_strength, "Denoising strength"),
+        (aesthetic_lr_im, "Aesthetic LR"),
+        (aesthetic_weight_im, "Aesthetic weight"),
+        (aesthetic_steps_im, "Aesthetic steps"),
+        (aesthetic_imgs_im, "Aesthetic embedding"),
+        (aesthetic_slerp_im, "Aesthetic slerp"),
+        (aesthetic_imgs_text_im, "Aesthetic text"),
+        (aesthetic_text_negative_im, "Aesthetic text negative"),
+        (aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
     ]
 
     token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
@@ -1183,6 +1255,7 @@ def create_ui(wrap_gradio_gpu_call):
                 new_embedding_name = gr.Textbox(label="Name")
                 initialization_text = gr.Textbox(label="Initialization text", value="*")
                 nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+                overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
 
                 with gr.Row():
                     with gr.Column(scale=3):
@@ -1191,9 +1264,25 @@ def create_ui(wrap_gradio_gpu_call):
                     with gr.Column():
                         create_embedding = gr.Button(value="Create embedding", variant='primary')
 
+            with gr.Tab(label="Create aesthetic images embedding"):
+
+                new_embedding_name_ae = gr.Textbox(label="Name")
+                process_src_ae = gr.Textbox(label='Source directory')
+                batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
+                with gr.Row():
+                    with gr.Column(scale=3):
+                        gr.HTML(value="")
+
+                    with gr.Column():
+                        create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
+
             with gr.Tab(label="Create hypernetwork"):
                 new_hypernetwork_name = gr.Textbox(label="Name")
                 new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
+                new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
+                new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+                overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
+                new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
 
                 with gr.Row():
                     with gr.Column(scale=3):
@@ -1207,6 +1296,7 @@ def create_ui(wrap_gradio_gpu_call):
                 process_dst = gr.Textbox(label='Destination directory')
                 process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
                 process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+                preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
 
                 with gr.Row():
                     process_flip = gr.Checkbox(label='Create flipped copies')
@@ -1222,14 +1312,17 @@ def create_ui(wrap_gradio_gpu_call):
                     run_preprocess = gr.Button(value="Preprocess", variant='primary')
 
             with gr.Tab(label="Train"):
-                gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
+                gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
                 with gr.Row():
                     train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
                     create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
                 with gr.Row():
                     train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
                     create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
-                learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+                with gr.Row():
+                    embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
+                    hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
+
                 batch_size = gr.Number(label='Batch size', value=1, precision=0)
                 dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
                 log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -1263,6 +1356,7 @@ def create_ui(wrap_gradio_gpu_call):
                 new_embedding_name,
                 initialization_text,
                 nvpt,
+                overwrite_old_embedding,
             ],
             outputs=[
                 train_embedding_name,
@@ -1271,11 +1365,30 @@ def create_ui(wrap_gradio_gpu_call):
             ]
         )
 
+        create_embedding_ae.click(
+            fn=aesthetic_clip.generate_imgs_embd,
+            inputs=[
+                new_embedding_name_ae,
+                process_src_ae,
+                batch_ae
+            ],
+            outputs=[
+                aesthetic_imgs,
+                aesthetic_imgs_im,
+                ti_output,
+                ti_outcome,
+            ]
+        )
+
         create_hypernetwork.click(
             fn=modules.hypernetworks.ui.create_hypernetwork,
             inputs=[
                 new_hypernetwork_name,
                 new_hypernetwork_sizes,
+                overwrite_old_hypernetwork,
+                new_hypernetwork_layer_structure,
+                new_hypernetwork_add_layer_norm,
+                new_hypernetwork_activation_func,
             ],
             outputs=[
                 train_hypernetwork_name,
@@ -1292,6 +1405,7 @@ def create_ui(wrap_gradio_gpu_call):
                 process_dst,
                 process_width,
                 process_height,
+                preprocess_txt_action,
                 process_flip,
                 process_split,
                 process_caption,
@@ -1308,7 +1422,7 @@ def create_ui(wrap_gradio_gpu_call):
             _js="start_training_textual_inversion",
             inputs=[
                 train_embedding_name,
-                learn_rate,
+                embedding_learn_rate,
                 batch_size,
                 dataset_directory,
                 log_directory,
@@ -1333,10 +1447,12 @@ def create_ui(wrap_gradio_gpu_call):
             _js="start_training_textual_inversion",
             inputs=[
                 train_hypernetwork_name,
-                learn_rate,
+                hypernetwork_learn_rate,
                 batch_size,
                 dataset_directory,
                 log_directory,
+                training_width,
+                training_height,
                 steps,
                 create_image_every,
                 save_embedding_every,
@@ -1533,6 +1649,7 @@ Requested path was: {f}
 
     def reload_scripts():
         modules.scripts.reload_script_body_only()
+        reload_javascript()  # need to refresh the html page
 
     reload_script_bodies.click(
         fn=reload_scripts,
@@ -1733,7 +1850,7 @@ Requested path was: {f}
             print(traceback.format_exc(), file=sys.stderr)
 
     def loadsave(path, x):
-        def apply_field(obj, field, condition=None):
+        def apply_field(obj, field, condition=None, init_field=None):
             key = path + "/" + field
 
             if getattr(obj,'custom_script_source',None) is not None:
@@ -1749,6 +1866,8 @@ Requested path was: {f}
                 print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
             else:
                 setattr(obj, field, saved_value)
+                if init_field is not None:
+                    init_field(saved_value)
 
         if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
             apply_field(x, 'visible')
@@ -1774,7 +1893,8 @@ Requested path was: {f}
         # Since there are many dropdowns that shouldn't be saved,
         # we only mark dropdowns that should be saved.
         if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False):
-            apply_field(x, 'value', lambda val: val in x.choices)
+            apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
+            apply_field(x, 'visible')
 
     visit(txt2img_interface, loadsave, "txt2img")
     visit(img2img_interface, loadsave, "img2img")
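The `apply_field` change in these hunks extends the ui-config restore path: a saved value is validated with an optional `condition` and, on success, also forwarded to an `init_field` callback so components created from the value (dropdown choices, aesthetic embeddings) can re-initialize. A self-contained sketch of that flow; the `saved` dict stands in for the parsed ui-config.json, and the `Dropdown` stub and explicit `key` argument are simplifications for illustration:

```python
saved = {"train/value": "b"}  # stand-in for ui-config.json contents

def apply_field(obj, field, key, condition=None, init_field=None):
    """Restore one saved attribute onto a component, with validation."""
    saved_value = saved.get(key)
    if saved_value is None:
        return
    if condition and not condition(saved_value):
        # keep the component's default and warn, as the diff does
        print(f'Warning: Bad ui setting value: {key}: {saved_value}')
    else:
        setattr(obj, field, saved_value)
        if init_field is not None:
            init_field(saved_value)  # let dependent state react to the restore

class Dropdown:
    choices = ["a", "b"]
    value = "a"

seen = []
dd = Dropdown()
apply_field(dd, "value", "train/value",
            condition=lambda v: v in dd.choices, init_field=seen.append)
```

The `condition` guard is what keeps a stale saved value (e.g. a deleted checkpoint name) from being written onto a dropdown whose choices no longer contain it.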
@@ -1788,23 +1908,30 @@ Requested path was: {f}
     return demo
 
 
-with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
-    javascript = f'<script>{jsfile.read()}</script>'
+def load_javascript(raw_response):
+    with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
+        javascript = f'<script>{jsfile.read()}</script>'
 
-jsdir = os.path.join(script_path, "javascript")
-for filename in sorted(os.listdir(jsdir)):
-    with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
-        javascript += f"\n<script>{jsfile.read()}</script>"
+    jsdir = os.path.join(script_path, "javascript")
+    for filename in sorted(os.listdir(jsdir)):
+        with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
+            javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
 
-javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
+    if cmd_opts.theme is not None:
+        javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
 
-if 'gradio_routes_templates_response' not in globals():
+    javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
+
     def template_response(*args, **kwargs):
-        res = gradio_routes_templates_response(*args, **kwargs)
-        res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
+        res = raw_response(*args, **kwargs)
+        res.body = res.body.replace(
+            b'</head>', f'{javascript}</head>'.encode("utf8"))
         res.init_headers()
         return res
 
-    gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
     gradio.routes.templates.TemplateResponse = template_response
+
+
+reload_javascript = partial(load_javascript,
+                            gradio.routes.templates.TemplateResponse)
+reload_javascript()
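The hunk above turns the one-shot script injection into a reusable `load_javascript(raw_response)`: it wraps Gradio's `TemplateResponse` so every rendered page gets the collected `<script>` tags spliced in before `</head>`, and `partial(load_javascript, gradio.routes.templates.TemplateResponse)` captures the original factory once, replacing the earlier `globals()` guard against double-wrapping. The same wrap-and-replace pattern, with a stand-in response class (`Response` and `original_factory` are stubs for this sketch, not webui names):

```python
class Response:
    """Stand-in for a template response: raw HTML body plus a headers hook."""
    def __init__(self, body):
        self.body = body
    def init_headers(self):
        pass  # the real class would recompute Content-Length here

def original_factory(body):
    return Response(body)

def load_javascript(raw_response, javascript="<script>hi()</script>"):
    def template_response(*args, **kwargs):
        res = raw_response(*args, **kwargs)          # call the captured original
        res.body = res.body.replace(
            b'</head>', f'{javascript}</head>'.encode("utf8"))
        res.init_headers()
        return res
    return template_response

patched = load_javascript(original_factory)          # partial(...) in the diff
res = patched(b"<html><head></head></html>")
```

Because the unpatched factory is bound at `partial` time, calling `reload_javascript()` again re-reads the script files without stacking a second wrapper on top of the first.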
@@ -23,3 +23,4 @@ resize-right
 torchdiffeq
 kornia
 lark
+inflection
@@ -22,3 +22,4 @@ resize-right==0.0.2
 torchdiffeq==0.2.3
 kornia==0.6.7
 lark==1.1.2
+inflection==0.5.1
@@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
     if info is None:
         raise RuntimeError(f"Unknown checkpoint: {x}")
     modules.sd_models.reload_model_weights(shared.sd_model, info)
+    p.sd_model = shared.sd_model
 
 
 def confirm_checkpoints(p, xs):
11
style.css
@@ -34,9 +34,10 @@
 .performance {
     font-size: 0.85em;
     color: #444;
-    display: flex;
-    justify-content: space-between;
-    white-space: nowrap;
+}
+
+.performance p{
+    display: inline-block;
 }
 
 .performance .time {
@@ -44,8 +45,6 @@
 }
 
 .performance .vram {
-    margin-left: 0;
-    text-align: right;
 }
 
 #txt2img_generate, #img2img_generate {
@@ -478,7 +477,7 @@ input[type="range"]{
     padding: 0;
 }
 
-#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
+#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_aesthetic_embeddings{
     max-width: 2.5em;
     min-width: 2.5em;
     height: 2.4em;
@@ -33,7 +33,7 @@ goto :launch
 :skip_venv
 
 :launch
-%PYTHON% launch.py
+%PYTHON% launch.py %*
 pause
 exit /b
59
webui.py
@@ -4,7 +4,7 @@ import time
 import importlib
 import signal
 import threading
+from fastapi import FastAPI
 from fastapi.middleware.gzip import GZipMiddleware
 
 from modules.paths import script_path
@@ -31,7 +31,6 @@ from modules.paths import script_path
 from modules.shared import cmd_opts
 import modules.hypernetworks.hypernetwork
 
-
 queue_lock = threading.Lock()
 
 
@@ -87,10 +86,6 @@ def initialize():
     shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
     shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
 
-
-def webui():
-    initialize()
-
     # make the program just exit at ctrl+c without waiting for anything
     def sigint_handler(sig, frame):
         print(f'Interrupted with signal {sig} in {frame}')
@@ -98,10 +93,38 @@ def webui():
 
     signal.signal(signal.SIGINT, sigint_handler)
 
-    while 1:
+
+def create_api(app):
+    from modules.api.api import Api
+    api = Api(app, queue_lock)
+    return api
+
+
+def wait_on_server(demo=None):
+    while 1:
+        time.sleep(0.5)
+        if demo and getattr(demo, 'do_restart', False):
+            time.sleep(0.5)
+            demo.close()
+            time.sleep(0.5)
+            break
+
+
+def api_only():
+    initialize()
+
+    app = FastAPI()
+    app.add_middleware(GZipMiddleware, minimum_size=1000)
+    api = create_api(app)
+
+    api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
+
+
+def webui():
+    launch_api = cmd_opts.api
+    initialize()
+
+    while 1:
         demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
 
         app, local_url, share_url = demo.launch(
             share=cmd_opts.share,
             server_name="0.0.0.0" if cmd_opts.listen else None,
@ -111,16 +134,13 @@ def webui():
|
||||||
inbrowser=cmd_opts.autolaunch,
|
inbrowser=cmd_opts.autolaunch,
|
||||||
prevent_thread_lock=True
|
prevent_thread_lock=True
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
|
|
||||||
while 1:
|
if (launch_api):
|
||||||
time.sleep(0.5)
|
create_api(app)
|
||||||
if getattr(demo, 'do_restart', False):
|
|
||||||
time.sleep(0.5)
|
wait_on_server(demo)
|
||||||
demo.close()
|
|
||||||
time.sleep(0.5)
|
|
||||||
break
|
|
||||||
|
|
||||||
sd_samplers.set_samplers()
|
sd_samplers.set_samplers()
|
||||||
|
|
||||||
|
@ -133,5 +153,10 @@ def webui():
|
||||||
print('Restarting Gradio')
|
print('Restarting Gradio')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
task = []
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
webui()
|
if cmd_opts.nowebui:
|
||||||
|
api_only()
|
||||||
|
else:
|
||||||
|
webui()
|
||||||
|
|
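The `wait_on_server` helper introduced in this diff extracts the polling loop that `webui()` previously ran inline: sleep, check whether the UI has flagged a restart, and close it if so. A rough self-contained illustration of that pattern (the stub object and the shortened sleep interval are mine, not the project's):

```python
import time

def wait_on_server(demo=None, poll_interval=0.01):
    # Poll until the UI object flags a restart, then shut it down.
    # The real helper hard-codes a 0.5 s interval; it is a parameter
    # here only so the example runs quickly.
    while True:
        time.sleep(poll_interval)
        if demo and getattr(demo, 'do_restart', False):
            demo.close()
            break

class StubDemo:
    # Hypothetical stand-in for the Gradio Blocks object.
    def __init__(self):
        self.do_restart = False
        self.closed = False

    def close(self):
        self.closed = True

demo = StubDemo()
demo.do_restart = True   # simulate the UI requesting a restart
wait_on_server(demo)
print(demo.closed)       # True: the loop closed the UI and returned
```

Because the helper returns instead of restarting anything itself, `webui()` can loop around it and rebuild the UI on each pass.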
webui.sh
@@ -138,4 +138,4 @@ fi
 printf "\n%s\n" "${delimiter}"
 printf "Launching launch.py..."
 printf "\n%s\n" "${delimiter}"
-"${python_cmd}" "${LAUNCH_SCRIPT}"
+"${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
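The only change to `webui.sh` is the appended `"$@"`, which forwards the script's own command-line arguments to `launch.py` unmodified, so flags such as `--api` or `--nowebui` reach the Python side where the new dispatch checks them. A minimal sketch of the receiving end (this tiny parser is a stand-in, not the project's actual `cmd_opts` definition):

```python
import argparse

# Hypothetical miniature of the real option parser: just the two
# boolean flags that the new `if __name__ == "__main__"` dispatch
# in webui.py consults.
parser = argparse.ArgumentParser()
parser.add_argument("--api", action="store_true")
parser.add_argument("--nowebui", action="store_true")

# Simulate `./webui.sh --nowebui` forwarding its argument via "$@".
opts = parser.parse_args(["--nowebui"])
print(opts.nowebui, opts.api)  # True False
```

Without the `"$@"`, arguments passed to `webui.sh` would be silently dropped before `launch.py` ever saw them.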