Merge branch 'master' into 6031_add_vae_generation_params and updated code slightly to fix merge conflict
commit b55127bbd4

29  .github/ISSUE_TEMPLATE/bug_report.yml (vendored)
@@ -37,20 +37,20 @@ body:
    id: what-should
    attributes:
      label: What should have happened?
-      description: tell what you think the normal behavior should be
+      description: Tell what you think the normal behavior should be
    validations:
      required: true
  - type: input
    id: commit
    attributes:
      label: Commit where the problem happens
-      description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+      description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
    validations:
      required: true
  - type: dropdown
    id: platforms
    attributes:
-      label: What platforms do you use to access UI ?
+      label: What platforms do you use to access the UI ?
      multiple: true
      options:
        - Windows
@@ -74,10 +74,27 @@ body:
    id: cmdargs
    attributes:
      label: Command Line Arguments
-      description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
+      description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
      render: Shell
    validations:
      required: true
+  - type: textarea
+    id: extensions
+    attributes:
+      label: List of extensions
+      description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise.
+    validations:
+      required: true
+  - type: textarea
+    id: logs
+    attributes:
+      label: Console logs
+      description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
+      render: Shell
+    validations:
+      required: true
  - type: textarea
    id: misc
    attributes:
-      label: Additional information, context and logs
-      description: Please provide us with any relevant additional info, context or log output.
+      label: Additional information
+      description: Please provide us with any relevant additional info or context.
@@ -18,8 +18,8 @@ More technical discussion about your changes go here, plus anything that a maint

List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
 - OS: [e.g. Windows, Linux]
-- Browser [e.g. chrome, safari]
-- Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
+- Browser: [e.g. chrome, safari]
+- Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]

**Screenshots or videos of your changes**
13  .github/workflows/on_pull_request.yaml (vendored)

@@ -19,22 +19,19 @@ jobs:
      - name: Checkout Code
        uses: actions/checkout@v3
      - name: Set up Python 3.10
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v4
        with:
          python-version: 3.10.6
-      - uses: actions/cache@v2
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
+          cache: pip
+          cache-dependency-path: |
+            **/requirements*txt
      - name: Install PyLint
        run: |
          python -m pip install --upgrade pip
          pip install pylint
      # This lets PyLint check to see if it can resolve imports
      - name: Install dependencies
-        run : |
+        run: |
          export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
          python launch.py
      - name: Analysing the code with pylint
10  .github/workflows/run_tests.yaml (vendored)

@@ -14,13 +14,11 @@ jobs:
        uses: actions/setup-python@v4
        with:
          python-version: 3.10.6
-      - uses: actions/cache@v3
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
-          restore-keys: ${{ runner.os }}-pip-
+          cache: pip
+          cache-dependency-path: |
+            **/requirements*txt
      - name: Run tests
-        run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
+        run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
      - name: Upload main app stdout-stderr
        uses: actions/upload-artifact@v3
        if: always()
1  .gitignore (vendored)

@@ -32,3 +32,4 @@ notification.mp3
/extensions
/test/stdout.txt
/test/stderr.txt
+/cache.json
663  LICENSE.txt (new file)

@@ -0,0 +1,663 @@
|
|||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (c) 2023 AUTOMATIC1111
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
19  README.md

@@ -17,7 +17,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- a man in a (tuxedo:1.21) - alternative syntax
- select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times
-- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
+- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
- Textual Inversion
  - have as many embeddings as you want and use any names you like for them
  - use multiple embeddings with different numbers of vectors per token

@@ -49,9 +49,9 @@ A browser interface based on Gradio library for Stable Diffusion.
- Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config
-- Random artist button
- Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview
+  - Can use a separate neural network to produce previews with almost none VRAM or compute requirement
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
- Styles, a way to save part of prompt and easily apply them via dropdown later
- Variations, a way to generate same image but with tiny differences

@@ -76,13 +76,22 @@ A browser interface based on Gradio library for Stable Diffusion.
- hypernetworks and embeddings options
- Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
- Clip skip
-- Use Hypernetworks
-- Use VAEs
+- Hypernetworks
+- Loras (same as Hypernetworks but more pretty)
+- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt.
+- Can select to load a different VAE from settings screen
+- Estimated completion time in progress bar
- API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
- Now without any bad letters!
- Load checkpoints in safetensors format
- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
- Now with a license!
- Reorder elements in the UI from settings screen
-

## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.

@@ -146,6 +155,8 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
- xformers - https://github.com/facebookresearch/xformers
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
+- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
3041  artists.csv
File diff suppressed because it is too large
99  configs/instruct-pix2pix.yaml (new file)

@@ -0,0 +1,99 @@
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.

model:
  base_learning_rate: 1.0e-04
  target: modules.models.diffusion.ddpm_edit.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: edited
    cond_stage_key: edit
    # image_size: 64
    # image_size: 32
    image_size: 16
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: hybrid
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: true
    load_ema: true

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 0 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 128
    num_workers: 1
    wrap: false
    validation:
      target: edit_dataset.EditDataset
      params:
        path: data/clip-filtered-dataset
        cache_dir: data/
        cache_name: data_10k
        split: val
        min_text_sim: 0.2
        min_image_sim: 0.75
        min_direction_sim: 0.2
        max_samples_per_prompt: 1
        min_resize_res: 512
        max_resize_res: 512
        crop_res: 512
        output_as_edit: False
        real_input: True
@@ -1,8 +1,7 @@
model:
-  base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  base_learning_rate: 7.5e-05
+  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
-    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1

@@ -12,29 +11,36 @@ model:
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
-    cond_stage_trainable: false
-    conditioning_key: crossattn
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: hybrid   # important
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
+    use_ema: False # we set this to false because this is an inference only config
+    finetune_keys: null
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
-        use_checkpoint: True
-        use_fp16: True
        image_size: 32 # unused
-        in_channels: 4
+        in_channels: 9  # 4 data + 4 downscaled image + 1 mask
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
-        num_head_channels: 64 # need to fix for flash-attn
+        num_heads: 8
        use_spatial_transformer: True
-        use_linear_in_transformer: True
        transformer_depth: 1
-        context_dim: 1024
+        context_dim: 768
+        use_checkpoint: True
        legacy: False

    first_stage_config:

@@ -43,7 +49,6 @@ model:
      embed_dim: 4
      monitor: val/rec_loss
      ddconfig:
-        #attn_type: "vanilla-xformers"
        double_z: true
        z_channels: 4
        resolution: 256

@@ -62,7 +67,4 @@ model:
          target: torch.nn.Identity

    cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
-        freeze: True
-        layer: "penultimate"
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
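The `in_channels: 9` change above ("4 data + 4 downscaled image + 1 mask") is the hybrid-conditioning layout for the inpainting UNet: the noisy latent is concatenated channel-wise with the VAE-encoded masked image and the downscaled mask. A rough sketch of that idea with illustrative tensor names, not code from this commit:

    import torch

    noisy_latent        = torch.randn(1, 4, 64, 64)  # 4 "data" channels being denoised
    masked_image_latent = torch.randn(1, 4, 64, 64)  # 4 channels: VAE-encoded masked source image
    mask                = torch.ones(1, 1, 64, 64)   # 1 channel: inpainting mask at latent resolution

    unet_input = torch.cat([noisy_latent, masked_image_latent, mask], dim=1)
    print(unet_input.shape)  # torch.Size([1, 9, 64, 64]) -- matches in_channels: 9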
@@ -1,8 +1,7 @@
import os
-import gc
import time
import warnings

import numpy as np
import torch

@@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack

-warnings.filterwarnings("ignore", category=UserWarning)

cached_ldsr_model: torch.nn.Module = None
26  extensions-builtin/Lora/extra_networks_lora.py (new file)

@@ -0,0 +1,26 @@
from modules import extra_networks, shared
import lora


class ExtraNetworkLora(extra_networks.ExtraNetwork):
    def __init__(self):
        super().__init__('lora')

    def activate(self, p, params_list):
        additional = shared.opts.sd_lora

        if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

        names = []
        multipliers = []
        for params in params_list:
            assert len(params.items) > 0

            names.append(params.items[0])
            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

        lora.load_loras(names, multipliers)

    def deactivate(self, p):
        pass
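For orientation, activate() above receives ExtraNetworkParams entries that the webui's prompt parser builds from `<lora:name:multiplier>` tags; the parser itself is not part of this diff. A hypothetical sketch of the shape of that input (names and values made up for illustration):

    from modules import extra_networks

    # What a prompt containing "<lora:myLora:0.8> <lora:otherLora>" would roughly produce:
    params_list = [
        extra_networks.ExtraNetworkParams(items=["myLora", "0.8"]),
        extra_networks.ExtraNetworkParams(items=["otherLora"]),  # no multiplier given -> defaults to 1.0
    ]

    # activate() reduces this to names=["myLora", "otherLora"], multipliers=[0.8, 1.0]
    # and then calls lora.load_loras(names, multipliers).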
207 extensions-builtin/Lora/lora.py (new file)
@ -0,0 +1,207 @@
import glob
import os
import re
import torch

from modules import shared, devices, sd_models

re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")


def convert_diffusers_name_to_compvis(key):
    def match(match_list, regex):
        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, re_unet_down_blocks):
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_unet_mid_blocks):
        return f"diffusion_model_middle_block_1_{m[1]}"

    if match(m, re_unet_up_blocks):
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_text_block):
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key


class LoraOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename


class LoraModule:
    def __init__(self, name):
        self.name = name
        self.multiplier = 1.0
        self.modules = {}
        self.mtime = None


class LoraUpDownModule:
    def __init__(self):
        self.up = None
        self.down = None
        self.alpha = None


def assign_lora_names_to_compvis_modules(sd_model):
    lora_layer_mapping = {}

    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    for name, module in shared.sd_model.model.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    sd_model.lora_layer_mapping = lora_layer_mapping


def load_lora(name, filename):
    lora = LoraModule(name)
    lora.mtime = os.path.getmtime(filename)

    sd = sd_models.read_state_dict(filename)

    keys_failed_to_match = []

    for key_diffusers, weight in sd.items():
        fullkey = convert_diffusers_name_to_compvis(key_diffusers)
        key, lora_key = fullkey.split(".", 1)

        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
        if sd_module is None:
            keys_failed_to_match.append(key_diffusers)
            continue

        lora_module = lora.modules.get(key, None)
        if lora_module is None:
            lora_module = LoraUpDownModule()
            lora.modules[key] = lora_module

        if lora_key == "alpha":
            lora_module.alpha = weight.item()
            continue

        if type(sd_module) == torch.nn.Linear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.Conv2d:
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'

        with torch.no_grad():
            module.weight.copy_(weight)

        module.to(device=devices.device, dtype=devices.dtype)

        if lora_key == "lora_up.weight":
            lora_module.up = module
        elif lora_key == "lora_down.weight":
            lora_module.down = module
        else:
            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'

    if len(keys_failed_to_match) > 0:
        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")

    return lora


def load_loras(names, multipliers=None):
    already_loaded = {}

    for lora in loaded_loras:
        if lora.name in names:
            already_loaded[lora.name] = lora

    loaded_loras.clear()

    loras_on_disk = [available_loras.get(name, None) for name in names]
    if any([x is None for x in loras_on_disk]):
        list_available_loras()

        loras_on_disk = [available_loras.get(name, None) for name in names]

    for i, name in enumerate(names):
        lora = already_loaded.get(name, None)

        lora_on_disk = loras_on_disk[i]
        if lora_on_disk is not None:
            if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
                lora = load_lora(name, lora_on_disk.filename)

        if lora is None:
            print(f"Couldn't find Lora with name {name}")
            continue

        lora.multiplier = multipliers[i] if multipliers else 1.0
        loaded_loras.append(lora)


def lora_forward(module, input, res):
    if len(loaded_loras) == 0:
        return res

    lora_layer_name = getattr(module, 'lora_layer_name', None)
    for lora in loaded_loras:
        module = lora.modules.get(lora_layer_name, None)
        if module is not None:
            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
            else:
                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

    return res


def lora_Linear_forward(self, input):
    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))


def lora_Conv2d_forward(self, input):
    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))


def list_available_loras():
    available_loras.clear()

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)

    for filename in sorted(candidates):
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]

        available_loras[name] = LoraOnDisk(name, filename)


available_loras = {}
loaded_loras = []

list_available_loras()
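The lora_forward function above adds a low-rank correction, up(down(x)), scaled by the user multiplier and by alpha divided by the rank. A small illustrative sketch of that arithmetic on plain tensors follows; the shapes and values here are assumptions for the example, not taken from the code.

# Illustrative only: the low-rank update applied by lora_forward, written out on plain tensors.
import torch

d_in, d_out, rank = 768, 768, 8
x = torch.randn(1, d_in)

W = torch.randn(d_out, d_in)      # frozen base weight of the wrapped Linear layer
down = torch.randn(rank, d_in)    # lora_down.weight
up = torch.randn(d_out, rank)     # lora_up.weight
alpha, multiplier = 4.0, 0.8

base = x @ W.T
# matches: res + up(down(input)) * multiplier * (alpha / up.weight.shape[1])
out = base + (x @ down.T) @ up.T * multiplier * (alpha / rank)
print(out.shape)  # torch.Size([1, 768])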
6 extensions-builtin/Lora/preload.py (new file)
@ -0,0 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
38 extensions-builtin/Lora/scripts/lora_script.py (new file)
@ -0,0 +1,38 @@
import torch
import gradio as gr

import lora
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared


def unload():
    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora


def before_ui():
    ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
    extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())


if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward

if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward

torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward

script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)


shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
}))
37 extensions-builtin/Lora/ui_extra_networks_lora.py (new file)
@ -0,0 +1,37 @@
import json
import os
import lora

from modules import shared, ui_extra_networks


class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('Lora')

    def refresh(self):
        lora.list_available_loras()

    def list_items(self):
        for name, lora_on_disk in lora.available_loras.items():
            path, ext = os.path.splitext(lora_on_disk.filename)
            previews = [path + ".png", path + ".preview.png"]

            preview = None
            for file in previews:
                if os.path.isfile(file):
                    preview = self.link_preview(file)
                    break

            yield {
                "name": name,
                "filename": path,
                "preview": preview,
                "search_term": self.search_terms_from_path(lora_on_disk.filename),
                "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                "local_preview": path + ".png",
            }

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.lora_dir]
@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm

from modules import modelloader, devices, script_callbacks, shared
from modules.shared import cmd_opts, opts
from modules.shared import cmd_opts, opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
    with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
        for h_idx in h_idx_list:
            if state.interrupted or state.skipped:
                break

            for w_idx in w_idx_list:
                if state.interrupted or state.skipped:
                    break

                in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)
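The SwinIR hunk above adds early-exit checks to a tiled loop so a long upscale can be cancelled between tiles. A standalone sketch of that pattern follows; the State class and process_tiles function are stand-ins for this example, not the actual modules.shared.state.

# Sketch of the tiling-with-interruption pattern, under the assumptions noted above.
class State:
    interrupted = False
    skipped = False

state = State()

def process_tiles(tiles, model):
    results = []
    for tile in tiles:
        if state.interrupted or state.skipped:
            break  # stop between tiles instead of finishing the whole upscale
        results.append(model(tile))
    return results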
|
@ -4,16 +4,10 @@
|
|||
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
||||
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
||||
|
||||
function checkBrackets(evt) {
|
||||
textArea = evt.target;
|
||||
tabName = evt.target.parentElement.parentElement.id.split("_")[0];
|
||||
counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
|
||||
|
||||
promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
|
||||
|
||||
errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
|
||||
errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
|
||||
errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
|
||||
function checkBrackets(evt, textArea, counterElt) {
|
||||
errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
|
||||
errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
|
||||
errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
|
||||
|
||||
openBracketRegExp = /\(/g;
|
||||
closeBracketRegExp = /\)/g;
|
||||
|
@ -86,22 +80,31 @@ function checkBrackets(evt) {
|
|||
}
|
||||
|
||||
if(counterElt.title != '') {
|
||||
counterElt.style = 'color: #FF5555;';
|
||||
counterElt.classList.add('error');
|
||||
} else {
|
||||
counterElt.style = '';
|
||||
counterElt.classList.remove('error');
|
||||
}
|
||||
}
|
||||
|
||||
function setupBracketChecking(id_prompt, id_counter){
|
||||
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
||||
var counter = gradioApp().getElementById(id_counter)
|
||||
textarea.addEventListener("input", function(evt){
|
||||
checkBrackets(evt, textarea, counter)
|
||||
});
|
||||
}
|
||||
|
||||
var shadowRootLoaded = setInterval(function() {
|
||||
var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
|
||||
if(shadowTextArea.length < 1) {
|
||||
return false;
|
||||
}
|
||||
var shadowRoot = document.querySelector('gradio-app').shadowRoot;
|
||||
if(! shadowRoot) return false;
|
||||
|
||||
clearInterval(shadowRootLoaded);
|
||||
var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
|
||||
if(shadowTextArea.length < 1) return false;
|
||||
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
|
||||
clearInterval(shadowRootLoaded);
|
||||
|
||||
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
|
||||
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
|
||||
setupBracketChecking('img2img_prompt', 'imgimg_token_counter')
|
||||
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
|
||||
}, 1000);
|
||||
|
|
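The prompt bracket counter above flags prompts whose round, square, or curly brackets do not pair up. The following is a standalone sketch of the same balance check, written in Python rather than the page's JavaScript, purely for illustration.

# Sketch only: count opening vs closing brackets the way the counter above does.
def bracket_errors(prompt: str) -> list[str]:
    errors = []
    for open_ch, close_ch, label in [("(", ")", "round"), ("[", "]", "square"), ("{", "}", "curly")]:
        if prompt.count(open_ch) != prompt.count(close_ch):
            errors.append(f"Different number of opening and closing {label} brackets detected.")
    return errors

print(bracket_errors("a ((cat) in a [box"))
# ['Different number of opening and closing round brackets detected.',
#  'Different number of opening and closing square brackets detected.']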
|
@ -1,50 +0,0 @@
|
|||
import random
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
import gradio as gr
|
||||
|
||||
art_symbol = '\U0001f3a8' # 🎨
|
||||
global_prompt = None
|
||||
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
|
||||
|
||||
|
||||
def roll_artist(prompt):
|
||||
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
|
||||
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
|
||||
|
||||
return prompt + ", " + artist.name if prompt != '' else artist.name
|
||||
|
||||
|
||||
def add_roll_button(prompt):
|
||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||
|
||||
roll.click(
|
||||
fn=roll_artist,
|
||||
_js="update_txt2img_tokens",
|
||||
inputs=[
|
||||
prompt,
|
||||
],
|
||||
outputs=[
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def after_component(component, **kwargs):
|
||||
global global_prompt
|
||||
|
||||
elem_id = kwargs.get('elem_id', None)
|
||||
if elem_id not in related_ids:
|
||||
return
|
||||
|
||||
if elem_id == "txt2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "txt2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
elif elem_id == "img2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "img2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
|
||||
|
||||
script_callbacks.on_after_component(after_component)
|
BIN html/card-no-preview.png (new file, 82 KiB; binary file not shown)
12 html/extra-networks-card.html (new file)
@ -0,0 +1,12 @@
<div class='card' {preview_html} onclick={card_clicked}>
    <div class='actions'>
        <div class='additional'>
            <ul>
                <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
            </ul>
            <span style="display:none" class='search_term'>{search_term}</span>
        </div>
        <span class='name'>{name}</span>
    </div>
</div>
8 html/extra-networks-no-cards.html (new file)
@ -0,0 +1,8 @@
<div class='nocards'>
    <h1>Nothing here. Add some content to the following directories:</h1>

    <ul>
        {dirs}
    </ul>
</div>
7 html/image-update.svg (new file, 989 B)
@ -0,0 +1,7 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
  <filter id='shadow' color-interpolation-filters="sRGB">
    <feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
    <feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
  </filter>
  <path style="filter:url(#shadow);" fill="#FFFFFF" d="M13.18 19C13.35 19.72 13.64 20.39 14.03 21H5C3.9 21 3 20.11 3 19V5C3 3.9 3.9 3 5 3H19C20.11 3 21 3.9 21 5V11.18C20.5 11.07 20 11 19.5 11C19.33 11 19.17 11 19 11.03V5H5V19H13.18M11.21 15.83L9.25 13.47L6.5 17H13.03C13.14 15.54 13.73 14.22 14.64 13.19L13.96 12.29L11.21 15.83M19 13.5V12L16.75 14.25L19 16.5V15C20.38 15 21.5 16.12 21.5 17.5C21.5 17.9 21.41 18.28 21.24 18.62L22.33 19.71C22.75 19.08 23 18.32 23 17.5C23 15.29 21.21 13.5 19 13.5M19 20C17.62 20 16.5 18.88 16.5 17.5C16.5 17.1 16.59 16.72 16.76 16.38L15.67 15.29C15.25 15.92 15 16.68 15 17.5C15 19.71 16.79 21.5 19 21.5V23L21.25 20.75L19 18.5V20Z" />
</svg>
@ -21,11 +21,16 @@ function dimensionChange(e, is_width, is_height){
    var targetElement = null;

    var tabIndex = get_tab_index('mode_img2img')
    if(tabIndex == 0){
    if(tabIndex == 0){ // img2img
        targetElement = gradioApp().querySelector('div[data-testid=image] img');
    } else if(tabIndex == 1){
    } else if(tabIndex == 1){ //Sketch
        targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
    } else if(tabIndex == 2){ // Inpaint
        targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
    } else if(tabIndex == 3){ // Inpaint sketch
        targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
    }

    if(targetElement){
@ -1,75 +1,96 @@
|
|||
addEventListener('keydown', (event) => {
|
||||
function keyupEditAttention(event){
|
||||
let target = event.originalTarget || event.composedPath()[0];
|
||||
if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
|
||||
if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
|
||||
if (! (event.metaKey || event.ctrlKey)) return;
|
||||
|
||||
|
||||
let plus = "ArrowUp"
|
||||
let minus = "ArrowDown"
|
||||
if (event.key != plus && event.key != minus) return;
|
||||
let isPlus = event.key == "ArrowUp"
|
||||
let isMinus = event.key == "ArrowDown"
|
||||
if (!isPlus && !isMinus) return;
|
||||
|
||||
let selectionStart = target.selectionStart;
|
||||
let selectionEnd = target.selectionEnd;
|
||||
// If the user hasn't selected anything, let's select their current parenthesis block
|
||||
if (selectionStart === selectionEnd) {
|
||||
let text = target.value;
|
||||
|
||||
function selectCurrentParenthesisBlock(OPEN, CLOSE){
|
||||
if (selectionStart !== selectionEnd) return false;
|
||||
|
||||
// Find opening parenthesis around current cursor
|
||||
const before = target.value.substring(0, selectionStart);
|
||||
let beforeParen = before.lastIndexOf("(");
|
||||
if (beforeParen == -1) return;
|
||||
let beforeParenClose = before.lastIndexOf(")");
|
||||
const before = text.substring(0, selectionStart);
|
||||
let beforeParen = before.lastIndexOf(OPEN);
|
||||
if (beforeParen == -1) return false;
|
||||
let beforeParenClose = before.lastIndexOf(CLOSE);
|
||||
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
||||
beforeParen = before.lastIndexOf("(", beforeParen - 1);
|
||||
beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
|
||||
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
|
||||
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
|
||||
}
|
||||
|
||||
// Find closing parenthesis around current cursor
|
||||
const after = target.value.substring(selectionStart);
|
||||
let afterParen = after.indexOf(")");
|
||||
if (afterParen == -1) return;
|
||||
let afterParenOpen = after.indexOf("(");
|
||||
const after = text.substring(selectionStart);
|
||||
let afterParen = after.indexOf(CLOSE);
|
||||
if (afterParen == -1) return false;
|
||||
let afterParenOpen = after.indexOf(OPEN);
|
||||
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
||||
afterParen = after.indexOf(")", afterParen + 1);
|
||||
afterParenOpen = after.indexOf("(", afterParenOpen + 1);
|
||||
afterParen = after.indexOf(CLOSE, afterParen + 1);
|
||||
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
|
||||
}
|
||||
if (beforeParen === -1 || afterParen === -1) return;
|
||||
if (beforeParen === -1 || afterParen === -1) return false;
|
||||
|
||||
// Set the selection to the text between the parenthesis
|
||||
const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
|
||||
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
|
||||
const lastColon = parenContent.lastIndexOf(":");
|
||||
selectionStart = beforeParen + 1;
|
||||
selectionEnd = selectionStart + lastColon;
|
||||
target.setSelectionRange(selectionStart, selectionEnd);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the user hasn't selected anything, let's select their current parenthesis block
|
||||
if(! selectCurrentParenthesisBlock('<', '>')){
|
||||
selectCurrentParenthesisBlock('(', ')')
|
||||
}
|
||||
|
||||
event.preventDefault();
|
||||
|
||||
if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
|
||||
target.value = target.value.slice(0, selectionStart) +
|
||||
"(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
|
||||
target.value.slice(selectionEnd);
|
||||
closeCharacter = ')'
|
||||
delta = opts.keyedit_precision_attention
|
||||
|
||||
target.focus();
|
||||
target.selectionStart = selectionStart + 1;
|
||||
target.selectionEnd = selectionEnd + 1;
|
||||
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
|
||||
closeCharacter = '>'
|
||||
delta = opts.keyedit_precision_extra
|
||||
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
|
||||
|
||||
} else {
|
||||
end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
|
||||
weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
||||
if (isNaN(weight)) return;
|
||||
if (event.key == minus) weight -= 0.1;
|
||||
if (event.key == plus) weight += 0.1;
|
||||
// do not include spaces at the end
|
||||
while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
|
||||
selectionEnd -= 1;
|
||||
}
|
||||
if(selectionStart == selectionEnd){
|
||||
return
|
||||
}
|
||||
|
||||
weight = parseFloat(weight.toPrecision(12));
|
||||
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
|
||||
|
||||
target.value = target.value.slice(0, selectionEnd + 1) +
|
||||
weight +
|
||||
target.value.slice(selectionEnd + 1 + end - 1);
|
||||
selectionStart += 1;
|
||||
selectionEnd += 1;
|
||||
}
|
||||
|
||||
target.focus();
|
||||
target.selectionStart = selectionStart;
|
||||
target.selectionEnd = selectionEnd;
|
||||
}
|
||||
// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
|
||||
// internal Svelte data binding remains in sync.
|
||||
target.dispatchEvent(new Event("input", { bubbles: true }));
|
||||
});
|
||||
end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
||||
weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
||||
if (isNaN(weight)) return;
|
||||
|
||||
weight += isPlus ? delta : -delta;
|
||||
weight = parseFloat(weight.toPrecision(12));
|
||||
if(String(weight).length == 1) weight += ".0"
|
||||
|
||||
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
|
||||
|
||||
target.focus();
|
||||
target.value = text;
|
||||
target.selectionStart = selectionStart;
|
||||
target.selectionEnd = selectionEnd;
|
||||
|
||||
updateInput(target)
|
||||
}
|
||||
|
||||
addEventListener('keydown', (event) => {
|
||||
keyupEditAttention(event);
|
||||
});
|
|
@ -1,7 +1,8 @@
|
|||
|
||||
function extensions_apply(_, _){
|
||||
disable = []
|
||||
update = []
|
||||
var disable = []
|
||||
var update = []
|
||||
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||
if(x.name.startsWith("enable_") && ! x.checked)
|
||||
disable.push(x.name.substr(7))
|
||||
|
@ -16,11 +17,24 @@ function extensions_apply(_, _){
|
|||
}
|
||||
|
||||
function extensions_check(){
|
||||
var disable = []
|
||||
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||
if(x.name.startsWith("enable_") && ! x.checked)
|
||||
disable.push(x.name.substr(7))
|
||||
})
|
||||
|
||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||
x.innerHTML = "Loading..."
|
||||
})
|
||||
|
||||
return []
|
||||
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
|
||||
|
||||
})
|
||||
|
||||
return [id, JSON.stringify(disable)]
|
||||
}
|
||||
|
||||
function install_extension_from_index(button, url){
|
||||
|
@ -29,7 +43,7 @@ function install_extension_from_index(button, url){
|
|||
|
||||
textarea = gradioApp().querySelector('#extension_to_install textarea')
|
||||
textarea.value = url
|
||||
textarea.dispatchEvent(new Event("input", { bubbles: true }))
|
||||
updateInput(textarea)
|
||||
|
||||
gradioApp().querySelector('#install_extension_button').click()
|
||||
}
|
||||
|
|
107
javascript/extraNetworks.js
Normal file
107
javascript/extraNetworks.js
Normal file
|
@ -0,0 +1,107 @@
|
|||
|
||||
function setupExtraNetworksForTab(tabname){
|
||||
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
|
||||
|
||||
var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
|
||||
var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
|
||||
var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
|
||||
var close = gradioApp().getElementById(tabname+'_extra_close')
|
||||
|
||||
search.classList.add('search')
|
||||
tabs.appendChild(search)
|
||||
tabs.appendChild(refresh)
|
||||
tabs.appendChild(close)
|
||||
|
||||
search.addEventListener("input", function(evt){
|
||||
searchTerm = search.value.toLowerCase()
|
||||
|
||||
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
||||
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
||||
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
var activePromptTextarea = {};
|
||||
|
||||
function setupExtraNetworks(){
|
||||
setupExtraNetworksForTab('txt2img')
|
||||
setupExtraNetworksForTab('img2img')
|
||||
|
||||
function registerPrompt(tabname, id){
|
||||
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||
|
||||
if (! activePromptTextarea[tabname]){
|
||||
activePromptTextarea[tabname] = textarea
|
||||
}
|
||||
|
||||
textarea.addEventListener("focus", function(){
|
||||
activePromptTextarea[tabname] = textarea;
|
||||
});
|
||||
}
|
||||
|
||||
registerPrompt('txt2img', 'txt2img_prompt')
|
||||
registerPrompt('txt2img', 'txt2img_neg_prompt')
|
||||
registerPrompt('img2img', 'img2img_prompt')
|
||||
registerPrompt('img2img', 'img2img_neg_prompt')
|
||||
}
|
||||
|
||||
onUiLoaded(setupExtraNetworks)
|
||||
|
||||
var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
|
||||
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
|
||||
|
||||
function tryToRemoveExtraNetworkFromPrompt(textarea, text){
|
||||
var m = text.match(re_extranet)
|
||||
if(! m) return false
|
||||
|
||||
var partToSearch = m[1]
|
||||
var replaced = false
|
||||
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
|
||||
m = found.match(re_extranet);
|
||||
if(m[1] == partToSearch){
|
||||
replaced = true;
|
||||
return ""
|
||||
}
|
||||
return found;
|
||||
})
|
||||
|
||||
if(replaced){
|
||||
textarea.value = newTextareaText
|
||||
return true;
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
function cardClicked(tabname, textToAdd, allowNegativePrompt){
|
||||
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
|
||||
|
||||
if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
|
||||
textarea.value = textarea.value + " " + textToAdd
|
||||
}
|
||||
|
||||
updateInput(textarea)
|
||||
}
|
||||
|
||||
function saveCardPreview(event, tabname, filename){
|
||||
var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
|
||||
var button = gradioApp().getElementById(tabname + '_save_preview')
|
||||
|
||||
textarea.value = filename
|
||||
updateInput(textarea)
|
||||
|
||||
button.click()
|
||||
|
||||
event.stopPropagation()
|
||||
event.preventDefault()
|
||||
}
|
||||
|
||||
function extraNetworksSearchButton(tabs_id, event){
|
||||
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
||||
button = event.target
|
||||
text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
||||
|
||||
searchTextarea.value = text
|
||||
updateInput(searchTextarea)
|
||||
}
|
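cardClicked and tryToRemoveExtraNetworkFromPrompt above toggle an extra-network tag: clicking a card removes a matching "<name:weight>" tag from the prompt if one is present, otherwise appends it. The following is a sketch of that toggle in Python for illustration only; toggle_extra_network is a hypothetical helper, not part of the codebase.

# Sketch only: toggle an "<lora:name:weight>" tag in a prompt, mirroring the JS above.
import re

re_extranet = re.compile(r"\s*<([^:>]+:[^:>]+):[\d.]+>")

def toggle_extra_network(prompt: str, tag: str) -> str:
    name = re.match(r"<([^:>]+:[^:>]+):", tag).group(1)   # e.g. "lora:catstyle"
    removed = re_extranet.sub(lambda m: "" if m.group(1) == name else m.group(0), prompt)
    return removed if removed != prompt else prompt + " " + tag

print(toggle_extra_network("a cat", "<lora:catstyle:0.8>"))                       # adds the tag
print(toggle_extra_network("a cat <lora:catstyle:0.8>", "<lora:catstyle:0.8>"))   # removes it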
|
@ -4,7 +4,7 @@ titles = {
|
|||
"Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
|
||||
"Sampling method": "Which algorithm to use to produce the image",
|
||||
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
|
||||
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
|
||||
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
|
||||
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
|
||||
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
|
||||
|
||||
|
@ -14,12 +14,14 @@ titles = {
|
|||
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
||||
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
||||
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
||||
"\u{1f3a8}": "Add a random artist to the prompt.",
|
||||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||
"\u{1f4c2}": "Open images output directory",
|
||||
"\u{1f4be}": "Save style",
|
||||
"\U0001F5D1": "Clear prompt",
|
||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||
"\u{1f4d2}": "Paste available values into the field",
|
||||
"\u{1f3b4}": "Show extra networks",
|
||||
|
||||
|
||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||
|
@ -48,7 +50,7 @@ titles = {
|
|||
|
||||
"None": "Do not do anything special",
|
||||
"Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
|
||||
"X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
|
||||
"X/Y/Z plot": "Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
|
||||
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
|
||||
|
||||
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
|
||||
|
@ -74,7 +76,7 @@ titles = {
|
|||
"Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
|
||||
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
|
||||
"Apply style": "Insert selected styles into prompt fields",
|
||||
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
|
||||
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style uses that as a placeholder for your prompt when you use the style in the future.",
|
||||
|
||||
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
|
||||
"Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
|
||||
|
@ -91,19 +93,24 @@ titles = {
|
|||
|
||||
"Weighted sum": "Result = A * (1 - M) + B * M",
|
||||
"Add difference": "Result = A + (B - C) * M",
|
||||
"No interpolation": "Result = A",
|
||||
|
||||
"Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
||||
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
|
||||
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
||||
|
||||
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
||||
|
||||
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
|
||||
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality.",
|
||||
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resolution and lower quality.",
|
||||
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resolution and extremely low quality.",
|
||||
|
||||
"Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
|
||||
"Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
|
||||
"Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.",
|
||||
"Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.",
|
||||
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders."
|
||||
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
|
||||
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
|
||||
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
|
||||
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited."
|
||||
}
|
||||
|
||||
|
||||
|
|
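The Learning rate hint above documents the multi-rate schedule syntax "rate_1:max_steps_1, rate_2:max_steps_2, ...", e.g. "0.005:100, 1e-3:1000, 1e-5". The following sketch shows one way such a string can be read; it is an illustration written for this document, not the webui's actual parser.

# Illustrative sketch: parse the learning-rate schedule syntax described in the hint above.
def parse_lr_schedule(spec: str):
    """Return a list of (rate, last_step) pairs; last_step is None for 'until the end'."""
    schedule = []
    for part in spec.split(","):
        rate, _, step = part.strip().partition(":")
        schedule.append((float(rate), int(step) if step else None))
    return schedule

print(parse_lr_schedule("0.005:100, 1e-3:1000, 1e-5"))
# [(0.005, 100), (0.001, 1000), (1e-05, None)]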
22 javascript/hires_fix.js (new file)
@ -0,0 +1,22 @@

function setInactive(elem, inactive){
    if(inactive){
        elem.classList.add('inactive')
    } else{
        elem.classList.remove('inactive')
    }
}

function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
    hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
    hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
    hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')

    gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""

    setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
    setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
    setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)

    return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
}
@ -148,9 +148,18 @@ function showGalleryImage() {
    if(e && e.parentElement.tagName == 'DIV'){
        e.style.cursor='pointer'
        e.style.userSelect='none'
        e.addEventListener('mousedown', function (evt) {

        var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1

        // For Firefox, listening on click first switches to the next image and then shows the lightbox.
        // If you know how to fix this without switching to the mousedown event, please do.
        // For other browsers the event is click, to make it possible to drag the picture.
        var event = isFirefox ? 'mousedown' : 'click'

        e.addEventListener(event, function (evt) {
            if(!opts.js_modal_lightbox || evt.button != 0) return;
            modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
            evt.preventDefault()
            showModal(evt)
        }, true);
    }
@ -10,10 +10,8 @@ ignore_ids_for_localization={
    modelmerger_tertiary_model_name: 'OPTION',
    train_embedding: 'OPTION',
    train_hypernetwork: 'OPTION',
    txt2img_style_index: 'OPTION',
    txt2img_style2_index: 'OPTION',
    img2img_style_index: 'OPTION',
    img2img_style2_index: 'OPTION',
    txt2img_styles: 'OPTION',
    img2img_styles: 'OPTION',
    setting_random_artist_categories: 'SPAN',
    setting_face_restoration_model: 'SPAN',
    setting_realesrgan_enabled_models: 'SPAN',
|
@ -1,82 +1,25 @@
|
|||
// code related to showing and updating progressbar shown as the image is being made
|
||||
global_progressbars = {}
|
||||
|
||||
|
||||
galleries = {}
|
||||
storedGallerySelections = {}
|
||||
galleryObservers = {}
|
||||
|
||||
// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
|
||||
timeoutIds = {}
|
||||
|
||||
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
||||
// gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
|
||||
// every time you use gr.HTML(elem_id='xxx'), so we handle this here
|
||||
var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
|
||||
var progressbarParent
|
||||
if(progressbar){
|
||||
progressbarParent = gradioApp().querySelector("#"+id_progressbar)
|
||||
} else{
|
||||
progressbar = gradioApp().getElementById(id_progressbar)
|
||||
progressbarParent = null
|
||||
}
|
||||
|
||||
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
||||
var interrupt = gradioApp().getElementById(id_interrupt)
|
||||
|
||||
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
|
||||
if(progressbar.innerText){
|
||||
let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
|
||||
if(document.title != newtitle){
|
||||
document.title = newtitle;
|
||||
}
|
||||
}else{
|
||||
let newtitle = 'Stable Diffusion'
|
||||
if(document.title != newtitle){
|
||||
document.title = newtitle;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
|
||||
global_progressbars[id_progressbar] = progressbar
|
||||
|
||||
var mutationObserver = new MutationObserver(function(m){
|
||||
if(timeoutIds[id_part]) return;
|
||||
|
||||
preview = gradioApp().getElementById(id_preview)
|
||||
gallery = gradioApp().getElementById(id_gallery)
|
||||
|
||||
if(preview != null && gallery != null){
|
||||
preview.style.width = gallery.clientWidth + "px"
|
||||
preview.style.height = gallery.clientHeight + "px"
|
||||
if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
|
||||
|
||||
//only watch gallery if there is a generation process going on
|
||||
check_gallery(id_gallery);
|
||||
|
||||
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
||||
if(progressDiv){
|
||||
timeoutIds[id_part] = window.setTimeout(function() {
|
||||
timeoutIds[id_part] = null
|
||||
requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
|
||||
}, 500)
|
||||
} else{
|
||||
if (skip) {
|
||||
skip.style.display = "none"
|
||||
}
|
||||
interrupt.style.display = "none"
|
||||
|
||||
//disconnect observer once generation finished, so user can close selected image if they want
|
||||
if (galleryObservers[id_gallery]) {
|
||||
galleryObservers[id_gallery].disconnect();
|
||||
galleries[id_gallery] = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
||||
}
|
||||
function rememberGallerySelection(id_gallery){
|
||||
storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
|
||||
}
|
||||
|
||||
function getGallerySelectedIndex(id_gallery){
|
||||
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
||||
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
||||
|
||||
let currentlySelectedIndex = -1
|
||||
galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })
|
||||
|
||||
return currentlySelectedIndex
|
||||
}
|
||||
|
||||
// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
|
||||
function check_gallery(id_gallery){
|
||||
let gallery = gradioApp().getElementById(id_gallery)
|
||||
// if gallery has no change, no need to setting up observer again.
|
||||
|
@ -85,10 +28,16 @@ function check_gallery(id_gallery){
|
|||
if(galleryObservers[id_gallery]){
|
||||
galleryObservers[id_gallery].disconnect();
|
||||
}
|
||||
let prevSelectedIndex = selected_gallery_index();
|
||||
|
||||
storedGallerySelections[id_gallery] = -1
|
||||
|
||||
galleryObservers[id_gallery] = new MutationObserver(function (){
|
||||
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
||||
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
||||
let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
|
||||
prevSelectedIndex = storedGallerySelections[id_gallery]
|
||||
storedGallerySelections[id_gallery] = -1
|
||||
|
||||
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
|
||||
// automatically re-open previously selected index (if exists)
|
||||
activeElement = gradioApp().activeElement;
|
||||
|
@ -120,30 +69,175 @@ function check_gallery(id_gallery){
|
|||
}
|
||||
|
||||
onUiUpdate(function(){
|
||||
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
||||
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
||||
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
|
||||
check_gallery('txt2img_gallery')
|
||||
check_gallery('img2img_gallery')
|
||||
})
|
||||
|
||||
function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
|
||||
btn = gradioApp().getElementById(id_part+"_check_progress");
|
||||
if(btn==null) return;
|
||||
|
||||
btn.click();
|
||||
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
||||
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
||||
var interrupt = gradioApp().getElementById(id_interrupt)
|
||||
if(progressDiv && interrupt){
|
||||
if (skip) {
|
||||
skip.style.display = "block"
|
||||
function request(url, data, handler, errorHandler){
|
||||
var xhr = new XMLHttpRequest();
|
||||
var url = url;
|
||||
xhr.open("POST", url, true);
|
||||
xhr.setRequestHeader("Content-Type", "application/json");
|
||||
xhr.onreadystatechange = function () {
|
||||
if (xhr.readyState === 4) {
|
||||
if (xhr.status === 200) {
|
||||
try {
|
||||
var js = JSON.parse(xhr.responseText);
|
||||
handler(js)
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
errorHandler()
|
||||
}
|
||||
} else{
|
||||
errorHandler()
|
||||
}
|
||||
}
|
||||
interrupt.style.display = "block"
|
||||
};
|
||||
var js = JSON.stringify(data);
|
||||
xhr.send(js);
|
||||
}
|
||||
|
||||
function pad2(x){
|
||||
return x<10 ? '0'+x : x
|
||||
}
|
||||
|
||||
function formatTime(secs){
|
||||
if(secs > 3600){
|
||||
return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60)
|
||||
} else if(secs > 60){
|
||||
return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60)
|
||||
} else{
|
||||
return Math.floor(secs) + "s"
|
||||
}
|
||||
}
|
||||
|
||||
function requestProgress(id_part){
|
||||
btn = gradioApp().getElementById(id_part+"_check_progress_initial");
|
||||
if(btn==null) return;
|
||||
function setTitle(progress){
|
||||
var title = 'Stable Diffusion'
|
||||
|
||||
btn.click();
|
||||
if(opts.show_progress_in_title && progress){
|
||||
title = '[' + progress.trim() + '] ' + title;
|
||||
}
|
||||
|
||||
if(document.title != title){
|
||||
document.title = title;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function randomId(){
|
||||
return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
|
||||
}
|
||||
|
||||
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
|
||||
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
|
||||
// calls onProgress every time there is a progress update
|
||||
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
|
||||
var dateStart = new Date()
|
||||
var wasEverActive = false
|
||||
var parentProgressbar = progressbarContainer.parentNode
|
||||
var parentGallery = gallery ? gallery.parentNode : null
|
||||
|
||||
var divProgress = document.createElement('div')
|
||||
divProgress.className='progressDiv'
|
||||
divProgress.style.display = opts.show_progressbar ? "" : "none"
|
||||
var divInner = document.createElement('div')
|
||||
divInner.className='progress'
|
||||
|
||||
divProgress.appendChild(divInner)
|
||||
parentProgressbar.insertBefore(divProgress, progressbarContainer)
|
||||
|
||||
if(parentGallery){
|
||||
var livePreview = document.createElement('div')
|
||||
livePreview.className='livePreview'
|
||||
parentGallery.insertBefore(livePreview, gallery)
|
||||
}
|
||||
|
||||
var removeProgressBar = function(){
|
||||
setTitle("")
|
||||
parentProgressbar.removeChild(divProgress)
|
||||
if(parentGallery) parentGallery.removeChild(livePreview)
|
||||
atEnd()
|
||||
}
|
||||
|
||||
var fun = function(id_task, id_live_preview){
|
||||
request("./internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
|
||||
if(res.completed){
|
||||
removeProgressBar()
|
||||
return
|
||||
}
|
||||
|
||||
var rect = progressbarContainer.getBoundingClientRect()
|
||||
|
||||
if(rect.width){
|
||||
divProgress.style.width = rect.width + "px";
|
||||
}
|
||||
|
||||
progressText = ""
|
||||
|
||||
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
|
||||
divInner.style.background = res.progress ? "" : "transparent"
|
||||
|
||||
if(res.progress > 0){
|
||||
progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
|
||||
}
|
||||
|
||||
if(res.eta){
|
||||
progressText += " ETA: " + formatTime(res.eta)
|
||||
}
|
||||
|
||||
|
||||
setTitle(progressText)
|
||||
|
||||
if(res.textinfo && res.textinfo.indexOf("\n") == -1){
|
||||
progressText = res.textinfo + " " + progressText
|
||||
}
|
||||
|
||||
divInner.textContent = progressText
|
||||
|
||||
var elapsedFromStart = (new Date() - dateStart) / 1000
|
||||
|
||||
if(res.active) wasEverActive = true;
|
||||
|
||||
if(! res.active && wasEverActive){
|
||||
removeProgressBar()
|
||||
return
|
||||
}
|
||||
|
||||
if(elapsedFromStart > 5 && !res.queued && !res.active){
|
||||
removeProgressBar()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
if(res.live_preview && gallery){
|
||||
var rect = gallery.getBoundingClientRect()
|
||||
if(rect.width){
|
||||
livePreview.style.width = rect.width + "px"
|
||||
livePreview.style.height = rect.height + "px"
|
||||
}
|
||||
|
||||
var img = new Image();
|
||||
img.onload = function() {
|
||||
livePreview.appendChild(img)
|
||||
if(livePreview.childElementCount > 2){
|
||||
livePreview.removeChild(livePreview.firstElementChild)
|
||||
}
|
||||
}
|
||||
img.src = res.live_preview;
|
||||
}
|
||||
|
||||
|
||||
if(onProgress){
|
||||
onProgress(res)
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
fun(id_task, res.id_live_preview);
|
||||
}, opts.live_preview_refresh_period || 500)
|
||||
}, function(){
|
||||
removeProgressBar()
|
||||
})
|
||||
}
|
||||
|
||||
fun(id_task, 0)
|
||||
}
|
||||
|
|
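The comment block above describes requestProgress(), which repeatedly posts to the "/internal/progress" route with an id_task and the id of the last live preview it received, and reads back progress, eta, textinfo, live_preview and completed fields. The following is a hedged sketch of the same polling loop as a standalone Python client; the host/port and the use of the requests library are assumptions for the example, while the field names mirror the JavaScript above.

# Sketch only: poll the progress endpoint the way requestProgress() does from the browser.
import time
import requests

def poll_progress(id_task: str, url: str = "http://127.0.0.1:7860/internal/progress"):
    id_live_preview = -1
    while True:
        res = requests.post(url, json={"id_task": id_task, "id_live_preview": id_live_preview}).json()
        if res.get("completed"):
            break
        print(f'{(res.get("progress") or 0) * 100:.0f}%  ETA: {res.get("eta")}  {res.get("textinfo") or ""}')
        id_live_preview = res.get("id_live_preview", id_live_preview)
        time.sleep(0.5)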
|
@ -1,8 +1,17 @@

function start_training_textual_inversion(){
    requestProgress('ti')
    gradioApp().querySelector('#ti_error').innerHTML=''

    return args_to_array(arguments)
    var id = randomId()
    requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
        gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
    })

    var res = args_to_array(arguments)

    res[0] = id

    return res
}
151
javascript/ui.js
151
javascript/ui.js
|
@ -45,16 +45,33 @@ function switch_to_txt2img(){
|
|||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
function switch_to_img2img(){
|
||||
function switch_to_img2img_tab(no){
|
||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
|
||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click();
|
||||
}
|
||||
function switch_to_img2img(){
|
||||
switch_to_img2img_tab(0);
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
function switch_to_sketch(){
|
||||
switch_to_img2img_tab(1);
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
function switch_to_inpaint(){
|
||||
switch_to_img2img_tab(2);
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
function switch_to_inpaint_sketch(){
|
||||
switch_to_img2img_tab(3);
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
function switch_to_inpaint(){
|
||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
|
||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[2].click();
|
||||
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
@ -87,9 +104,11 @@ function create_tab_index_args(tabId, args){
|
|||
return res
|
||||
}
|
||||
|
||||
function get_extras_tab_index(){
|
||||
const [,,...args] = [...arguments]
|
||||
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
|
||||
function get_img2img_tab_index() {
|
||||
let res = args_to_array(arguments)
|
||||
res.splice(-2)
|
||||
res[0] = get_tab_index('mode_img2img')
|
||||
return res
|
||||
}
|
||||
|
||||
function create_submit_args(args){
|
||||
|
@ -109,19 +128,51 @@ function create_submit_args(args){
|
|||
return res
|
||||
}
|
||||
|
||||
function submit(){
|
||||
requestProgress('txt2img')
|
||||
function showSubmitButtons(tabname, show){
|
||||
gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block"
|
||||
gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
|
||||
}
|
||||
|
||||
return create_submit_args(arguments)
|
||||
function submit(){
|
||||
rememberGallerySelection('txt2img_gallery')
|
||||
showSubmitButtons('txt2img', false)
|
||||
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
|
||||
showSubmitButtons('txt2img', true)
|
||||
|
||||
})
|
||||
|
||||
var res = create_submit_args(arguments)
|
||||
|
||||
res[0] = id
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
function submit_img2img(){
|
||||
requestProgress('img2img')
|
||||
rememberGallerySelection('img2img_gallery')
|
||||
showSubmitButtons('img2img', false)
|
||||
|
||||
res = create_submit_args(arguments)
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
||||
showSubmitButtons('img2img', true)
|
||||
})
|
||||
|
||||
res[0] = get_tab_index('mode_img2img')
|
||||
var res = create_submit_args(arguments)
|
||||
|
||||
res[0] = id
|
||||
res[1] = get_tab_index('mode_img2img')
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
function modelmerger(){
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
||||
|
||||
var res = create_submit_args(arguments)
|
||||
res[0] = id
|
||||
return res
|
||||
}
|
||||
|
||||
|
@ -140,27 +191,17 @@ function confirm_clear_prompt(prompt, negative_prompt) {
|
|||
return [prompt, negative_prompt]
|
||||
}
|
||||
|
||||
|
||||
|
||||
opts = {}
|
||||
function apply_settings(jsdata){
|
||||
console.log(jsdata)
|
||||
|
||||
opts = JSON.parse(jsdata)
|
||||
|
||||
return jsdata
|
||||
}
|
||||
|
||||
onUiUpdate(function(){
|
||||
if(Object.keys(opts).length != 0) return;
|
||||
|
||||
json_elem = gradioApp().getElementById('settings_json')
|
||||
if(json_elem == null) return;
|
||||
|
||||
textarea = json_elem.querySelector('textarea')
|
||||
jsdata = textarea.value
|
||||
var textarea = json_elem.querySelector('textarea')
|
||||
var jsdata = textarea.value
|
||||
opts = JSON.parse(jsdata)
|
||||
|
||||
executeCallbacks(optionsChangedCallbacks);
|
||||
|
||||
Object.defineProperty(textarea, 'value', {
|
||||
set: function(newValue) {
|
||||
|
@ -171,6 +212,8 @@ onUiUpdate(function(){
|
|||
if (oldValue != newValue) {
|
||||
opts = JSON.parse(textarea.value)
|
||||
}
|
||||
|
||||
executeCallbacks(optionsChangedCallbacks);
|
||||
},
|
||||
get: function() {
|
||||
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
|
||||
|
@ -180,14 +223,29 @@ onUiUpdate(function(){
|
|||
|
||||
json_elem.parentElement.style.display="none"
|
||||
|
||||
if (!txt2img_textarea) {
|
||||
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
||||
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
||||
}
|
||||
if (!img2img_textarea) {
|
||||
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
||||
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
||||
}
|
||||
function registerTextarea(id, id_counter, id_button){
|
||||
var prompt = gradioApp().getElementById(id)
|
||||
var counter = gradioApp().getElementById(id_counter)
|
||||
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||
|
||||
if(counter.parentElement == prompt.parentElement){
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
prompt.parentElement.insertBefore(counter, prompt)
|
||||
counter.classList.add("token-counter")
|
||||
prompt.parentElement.style.position = "relative"
|
||||
|
||||
textarea.addEventListener("input", function(){
|
||||
update_token_counter(id_button);
|
||||
});
|
||||
}
|
||||
|
||||
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
|
||||
registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button')
|
||||
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
|
||||
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
|
||||
|
||||
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
|
||||
settings_tabs = gradioApp().querySelector('#settings div')
|
||||
|
@ -201,6 +259,18 @@ onUiUpdate(function(){
|
|||
}
|
||||
})
|
||||
|
||||
onOptionsChanged(function(){
|
||||
elem = gradioApp().getElementById('sd_checkpoint_hash')
|
||||
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
||||
shorthash = sd_checkpoint_hash.substr(0,10)
|
||||
|
||||
if(elem && elem.textContent != shorthash){
|
||||
elem.textContent = shorthash
|
||||
elem.title = sd_checkpoint_hash
|
||||
elem.href = "https://google.com/search?q=" + sd_checkpoint_hash
|
||||
}
|
||||
})
|
||||
|
||||
let txt2img_textarea, img2img_textarea = undefined;
|
||||
let wait_time = 800
|
||||
let token_timeout;
|
||||
|
@ -231,3 +301,18 @@ function restart_reload(){
|
|||
|
||||
return []
|
||||
}
|
||||
|
||||
// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits
|
||||
// will only be visible on the web page and not sent to python.
|
||||
function updateInput(target){
|
||||
let e = new Event("input", { bubbles: true })
|
||||
Object.defineProperty(e, "target", {value: target})
|
||||
target.dispatchEvent(e);
|
||||
}
|
||||
|
||||
|
||||
var desiredCheckpointName = null;
|
||||
function selectCheckpoint(name){
|
||||
desiredCheckpointName = name;
|
||||
gradioApp().getElementById('change_checkpoint').click()
|
||||
}
|
||||
|
|
79 launch.py
|
@ -14,6 +14,38 @@ python = sys.executable
|
|||
git = os.environ.get('GIT', "git")
|
||||
index_url = os.environ.get('INDEX_URL', "")
|
||||
stored_commit_hash = None
|
||||
skip_install = False
|
||||
|
||||
|
||||
def check_python_version():
|
||||
is_windows = platform.system() == "Windows"
|
||||
major = sys.version_info.major
|
||||
minor = sys.version_info.minor
|
||||
micro = sys.version_info.micro
|
||||
|
||||
if is_windows:
|
||||
supported_minors = [10]
|
||||
else:
|
||||
supported_minors = [7, 8, 9, 10, 11]
|
||||
|
||||
if not (major == 3 and minor in supported_minors):
|
||||
import modules.errors
|
||||
|
||||
modules.errors.print_error_explanation(f"""
|
||||
INCOMPATIBLE PYTHON VERSION
|
||||
|
||||
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
||||
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
||||
or any other error regarding unsuccessful package (library) installation,
|
||||
please downgrade (or upgrade) to the latest version of 3.10 Python
|
||||
and delete current Python and "venv" folder in WebUI's directory.
|
||||
|
||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
|
||||
|
||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
||||
|
||||
Use --skip-python-version-check to suppress this warning.
|
||||
""")
|
||||
|
||||
|
||||
def commit_hash():
|
||||
|
@ -47,10 +79,19 @@ def extract_opt(args, name):
|
|||
return args, is_present, opt
|
||||
|
||||
|
||||
def run(command, desc=None, errdesc=None, custom_env=None):
|
||||
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
|
||||
if desc is not None:
|
||||
print(desc)
|
||||
|
||||
if live:
|
||||
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"""{errdesc or 'Error running command'}.
|
||||
Command: {command}
|
||||
Error code: {result.returncode}""")
|
||||
|
||||
return ""
|
||||
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
|
||||
|
||||
if result.returncode != 0:
|
||||
|
@ -89,6 +130,9 @@ def run_python(code, desc=None, errdesc=None):
|
|||
|
||||
|
||||
def run_pip(args, desc=None):
if skip_install:
return

index_url_line = f' --index-url {index_url}' if index_url != '' else ''
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
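# Illustration (not part of this commit): a call such as run_pip("install xformers==0.0.16rc425", "xformers")
# expands, per the f-string above, to roughly:
#   "<python>" -m pip install xformers==0.0.16rc425 --prefer-binary [--index-url <INDEX_URL>]
# where the --index-url part is appended only when the INDEX_URL environment variable is set.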
|
||||
|
||||
|
@ -104,18 +148,18 @@ def git_clone(url, dir, name, commithash=None):
|
|||
if commithash is None:
|
||||
return
|
||||
|
||||
current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
||||
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
||||
if current_hash == commithash:
|
||||
return
|
||||
|
||||
run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||
run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
||||
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
||||
return
|
||||
|
||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
||||
|
||||
if commithash is not None:
|
||||
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||
|
||||
|
||||
def version_check(commit):
|
||||
|
@ -173,7 +217,9 @@ def run_extensions_installers(settings_file):
|
|||
|
||||
|
||||
def prepare_environment():
|
||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
|
||||
global skip_install
|
||||
|
||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
|
||||
|
@ -181,8 +227,6 @@ def prepare_environment():
|
|||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
||||
|
||||
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
|
||||
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
|
@ -203,19 +247,25 @@ def prepare_environment():
|
|||
|
||||
sys.argv, _ = extract_arg(sys.argv, '-f')
|
||||
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
|
||||
sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
|
||||
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
|
||||
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
|
||||
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
|
||||
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
|
||||
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
|
||||
xformers = '--xformers' in sys.argv
|
||||
ngrok = '--ngrok' in sys.argv
|
||||
|
||||
if not skip_python_version_check:
|
||||
check_python_version()
|
||||
|
||||
commit = commit_hash()
|
||||
|
||||
print(f"Python {sys.version}")
|
||||
print(f"Commit hash: {commit}")
|
||||
|
||||
if not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
|
||||
|
||||
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||
|
||||
if not skip_torch_cuda_test:
|
||||
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
||||
|
@ -232,14 +282,14 @@ def prepare_environment():
|
|||
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
||||
if platform.system() == "Windows":
|
||||
if platform.python_version().startswith("3.10"):
|
||||
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
|
||||
run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers")
|
||||
else:
|
||||
print("Installation of xformers is not supported in this version of Python.")
|
||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
||||
if not is_installed("xformers"):
|
||||
exit(0)
|
||||
elif platform.system() == "Linux":
|
||||
run_pip("install xformers", "xformers")
|
||||
run_pip("install xformers==0.0.16rc425", "xformers")
|
||||
|
||||
if not is_installed("pyngrok") and ngrok:
|
||||
run_pip("install pyngrok", "ngrok")
|
||||
|
@ -279,9 +329,12 @@ def tests(test_dir):
|
|||
sys.argv.append("./test/test_files/empty.pt")
|
||||
if "--skip-torch-cuda-test" not in sys.argv:
|
||||
sys.argv.append("--skip-torch-cuda-test")
|
||||
if "--disable-nan-check" not in sys.argv:
|
||||
sys.argv.append("--disable-nan-check")
|
||||
|
||||
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
|
||||
|
||||
os.environ['COMMANDLINE_ARGS'] = ""
|
||||
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
|
||||
proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
|
||||
|
||||
|
|
|
@ -11,18 +11,20 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
|||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
||||
from modules.api.models import *
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.extras import run_extras
|
||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list, find_checkpoint_config
|
||||
from modules.sd_models import checkpoints_list
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import List
|
||||
import piexif
|
||||
import piexif.helper
|
||||
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
|
@ -45,32 +47,46 @@ def validate_sampler_name(name):
|
|||
|
||||
def setUpscalers(req: dict):
|
||||
reqDict = vars(req)
|
||||
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
|
||||
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
|
||||
reqDict.pop('upscaler_1')
|
||||
reqDict.pop('upscaler_2')
|
||||
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
||||
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
||||
return reqDict
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
return Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
|
||||
# Copy any text-only metadata
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
if opts.samples_format.lower() == 'png':
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
|
||||
|
||||
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
|
||||
parameters = image.info.get('parameters', None)
|
||||
exif_bytes = piexif.dump({
|
||||
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
|
||||
})
|
||||
if opts.samples_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
|
||||
else:
|
||||
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid image format")
|
||||
|
||||
image.save(
|
||||
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
|
||||
return base64.b64encode(bytes_data)
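For context, a minimal round-trip sketch (illustration only, not part of this commit; the helper name is hypothetical): the function above returns base64-encoded bytes, which a client can turn back into a PIL image.

import base64
import io
from PIL import Image

def decode_encoded_image(encoded: bytes) -> Image.Image:
    # reverse of encode_pil_to_base64: base64 -> raw bytes -> PIL image
    return Image.open(io.BytesIO(base64.b64decode(encoded)))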
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
|
@ -126,8 +142,6 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
||||
|
@ -135,6 +149,7 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
|
||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
if shared.cmd_opts.api_auth:
|
||||
|
@ -245,7 +260,7 @@ class Api:
|
|||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||
|
||||
with self.queue_lock:
|
||||
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||
|
||||
|
@ -261,7 +276,7 @@ class Api:
|
|||
reqDict.pop('imageList')
|
||||
|
||||
with self.queue_lock:
|
||||
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
|
@ -285,7 +300,7 @@ class Api:
|
|||
# copy from check_progress_call of ui.py
|
||||
|
||||
if shared.state.job_count == 0:
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
||||
|
||||
# avoid dividing by zero
|
||||
progress = 0.01
|
||||
|
@ -307,7 +322,7 @@ class Api:
|
|||
if shared.state.current_image and not req.skip_current_image:
|
||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||
|
||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
|
@ -361,16 +376,19 @@ class Api:
|
|||
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
||||
|
||||
def get_upscalers(self):
|
||||
upscalers = []
|
||||
|
||||
for upscaler in shared.sd_upscalers:
|
||||
u = upscaler.scaler
|
||||
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
|
||||
|
||||
return upscalers
|
||||
return [
|
||||
{
|
||||
"name": upscaler.name,
|
||||
"model_name": upscaler.scaler.model_name,
|
||||
"model_path": upscaler.data_path,
|
||||
"model_url": None,
|
||||
"scale": upscaler.scale,
|
||||
}
|
||||
for upscaler in shared.sd_upscalers
|
||||
]
|
||||
|
||||
def get_sd_models(self):
|
||||
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
|
@ -389,12 +407,6 @@ class Api:
|
|||
|
||||
return styleList
|
||||
|
||||
def get_artists_categories(self):
|
||||
return shared.artist_db.cats
|
||||
|
||||
def get_artists(self):
|
||||
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
|
||||
|
||||
def get_embeddings(self):
|
||||
db = sd_hijack.model_hijack.embedding_db
|
||||
|
||||
|
@ -479,7 +491,7 @@ class Api:
|
|||
def train_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
shared.loaded_hypernetworks = []
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
filename = ''
|
||||
|
@ -490,16 +502,49 @@ class Api:
|
|||
except Exception as e:
|
||||
error = e
|
||||
finally:
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
||||
return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
|
||||
except AssertionError as msg:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding error: {error}".format(error = error))
|
||||
return TrainResponse(info="train embedding error: {error}".format(error=error))
|
||||
|
||||
def get_memory(self):
|
||||
try:
|
||||
import os, psutil
|
||||
process = psutil.Process(os.getpid())
|
||||
res = process.memory_info() # only rss is cross-platform guaranteed so we don't rely on other values
|
||||
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
||||
ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total }
|
||||
except Exception as err:
|
||||
ram = { 'error': f'{err}' }
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
s = torch.cuda.mem_get_info()
|
||||
system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] }
|
||||
s = dict(torch.cuda.memory_stats(shared.device))
|
||||
allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] }
|
||||
reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] }
|
||||
active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] }
|
||||
inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] }
|
||||
warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
|
||||
cuda = {
|
||||
'system': system,
|
||||
'active': active,
|
||||
'allocated': allocated,
|
||||
'reserved': reserved,
|
||||
'inactive': inactive,
|
||||
'events': warnings,
|
||||
}
|
||||
else:
|
||||
cuda = { 'error': 'unavailable' }
|
||||
except Exception as err:
|
||||
cuda = { 'error': f'{err}' }
|
||||
return MemoryResponse(ram = ram, cuda = cuda)
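For reference, a minimal client-side sketch for the /sdapi/v1/memory route registered above (illustration only; host and port are assumptions for a default local install):

import requests

info = requests.get("http://127.0.0.1:7860/sdapi/v1/memory").json()
print(info["ram"], info["cuda"])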
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
|
|
|
@ -168,6 +168,7 @@ class ProgressResponse(BaseModel):
|
|||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
||||
|
||||
class InterrogateRequest(BaseModel):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
@ -219,13 +220,15 @@ class UpscalerItem(BaseModel):
|
|||
model_name: Optional[str] = Field(title="Model Name")
|
||||
model_path: Optional[str] = Field(title="Path")
|
||||
model_url: Optional[str] = Field(title="URL")
|
||||
scale: Optional[float] = Field(title="Scale")
|
||||
|
||||
class SDModelItem(BaseModel):
|
||||
title: str = Field(title="Title")
|
||||
model_name: str = Field(title="Model Name")
|
||||
hash: str = Field(title="Hash")
|
||||
hash: Optional[str] = Field(title="Short hash")
|
||||
sha256: Optional[str] = Field(title="sha256 hash")
|
||||
filename: str = Field(title="Filename")
|
||||
config: str = Field(title="Config file")
|
||||
config: Optional[str] = Field(title="Config file")
|
||||
|
||||
class HypernetworkItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
|
@ -260,3 +263,7 @@ class EmbeddingItem(BaseModel):
|
|||
class EmbeddingsResponse(BaseModel):
|
||||
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
ram: dict = Field(title="RAM", description="System memory stats")
|
||||
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
import os.path
|
||||
import csv
|
||||
from collections import namedtuple
|
||||
|
||||
Artist = namedtuple("Artist", ['name', 'weight', 'category'])
|
||||
|
||||
|
||||
class ArtistsDatabase:
|
||||
def __init__(self, filename):
|
||||
self.cats = set()
|
||||
self.artists = []
|
||||
|
||||
if not os.path.exists(filename):
|
||||
return
|
||||
|
||||
with open(filename, "r", newline='', encoding="utf8") as file:
|
||||
reader = csv.DictReader(file)
|
||||
|
||||
for row in reader:
|
||||
artist = Artist(row["artist"], float(row["score"]), row["category"])
|
||||
self.artists.append(artist)
|
||||
self.cats.add(artist.category)
|
||||
|
||||
def categories(self):
|
||||
return sorted(self.cats)
|
|
@ -4,7 +4,7 @@ import threading
|
|||
import traceback
|
||||
import time
|
||||
|
||||
from modules import shared
|
||||
from modules import shared, progress
|
||||
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
@ -22,12 +22,23 @@ def wrap_queued_call(func):
|
|||
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):

shared.state.begin()
# if the first argument is a string that says "task(...)", it is treated as a job id
if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
id_task = args[0]
progress.add_task_to_queue(id_task)
else:
id_task = None

with queue_lock:
res = func(*args, **kwargs)
shared.state.begin()
progress.start_task(id_task)

shared.state.end()
try:
res = func(*args, **kwargs)
finally:
progress.finish_task(id_task)

shared.state.end()

return res
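A note on the job-id convention used above (illustration; the exact id text is generated on the UI side and is an assumption here): the first positional argument is treated as a job id only when it looks like "task(...)".

id_task = "task(8f3c2a1d)"  # hypothetical id of the kind the front end would pass in
assert id_task.startswith("task(") and id_task.endswith(")")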
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch
|
|||
import modules.face_restoration
|
||||
import modules.shared
|
||||
from modules import shared, devices, modelloader
|
||||
from modules.paths import script_path, models_path
|
||||
from modules.paths import models_path
|
||||
|
||||
# codeformer people made a choice to include modified basicsr library to their project which makes
|
||||
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
||||
|
|
|
@ -2,6 +2,8 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules import devices
|
||||
|
||||
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
||||
|
||||
|
||||
|
@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module):
|
|||
t_358, = inputs
|
||||
t_359 = t_358.permute(*[0, 3, 1, 2])
|
||||
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
||||
t_360 = self.n_Conv_0(t_359_padded)
|
||||
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
|
||||
t_361 = F.relu(t_360)
|
||||
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
||||
t_362 = self.n_MaxPool_0(t_361)
|
||||
|
|
|
@ -34,14 +34,18 @@ def get_cuda_device_string():
|
|||
return "cuda"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
def get_optimal_device_name():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device(get_cuda_device_string())
|
||||
return get_cuda_device_string()
|
||||
|
||||
if has_mps():
|
||||
return torch.device("mps")
|
||||
return "mps"
|
||||
|
||||
return cpu
|
||||
return "cpu"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
return torch.device(get_optimal_device_name())
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
|
@ -79,6 +83,16 @@ cpu = torch.device("cpu")
|
|||
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
dtype_unet = torch.float16
|
||||
unet_needs_upcast = False
|
||||
|
||||
|
||||
def cond_cast_unet(input):
|
||||
return input.to(dtype_unet) if unet_needs_upcast else input
|
||||
|
||||
|
||||
def cond_cast_float(input):
|
||||
return input.float() if unet_needs_upcast else input
|
||||
|
||||
|
||||
def randn(seed, shape):
|
||||
|
@ -106,6 +120,42 @@ def autocast(disable=False):
|
|||
return torch.autocast("cuda")
|
||||
|
||||
|
||||
def without_autocast(disable=False):
|
||||
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
||||
|
||||
|
||||
class NansException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def test_for_nans(x, where):
|
||||
from modules import shared
|
||||
|
||||
if shared.cmd_opts.disable_nan_check:
|
||||
return
|
||||
|
||||
if not torch.all(torch.isnan(x)).item():
|
||||
return
|
||||
|
||||
if where == "unet":
|
||||
message = "A tensor with all NaNs was produced in Unet."
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
|
||||
|
||||
elif where == "vae":
|
||||
message = "A tensor with all NaNs was produced in VAE."
|
||||
|
||||
if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
|
||||
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
|
||||
else:
|
||||
message = "A tensor with all NaNs was produced."
|
||||
|
||||
message += " Use --disable-nan-check commandline argument to disable this check."
|
||||
|
||||
raise NansException(message)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
orig_tensor_to = torch.Tensor.to
|
||||
def tensor_to_fix(self, *args, **kwargs):
|
||||
|
@ -139,8 +189,10 @@ orig_Tensor_cumsum = torch.Tensor.cumsum
|
|||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
if input.device.type == 'mps':
|
||||
output_dtype = kwargs.get('dtype', input.dtype)
|
||||
if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
|
||||
if output_dtype == torch.int64:
|
||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||
return cumsum_func(input, *args, **kwargs)
|
||||
|
||||
|
||||
|
@ -151,8 +203,7 @@ if has_mps():
|
|||
torch.nn.functional.layer_norm = layer_norm_fix
|
||||
torch.Tensor.numpy = numpy_fix
|
||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||
if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
|
||||
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||
orig_narrow = torch.narrow
|
||||
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
|
||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||
|
|
|
@ -19,11 +19,23 @@ def display(e: Exception, task):
|
|||
message = str(e)
|
||||
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
||||
print_error_explanation("""
|
||||
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
|
||||
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
|
||||
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
|
||||
""")
|
||||
|
||||
|
||||
already_displayed = {}
|
||||
|
||||
|
||||
def display_once(e: Exception, task):
|
||||
if task in already_displayed:
|
||||
return
|
||||
|
||||
display(e, task)
|
||||
|
||||
already_displayed[task] = 1
|
||||
|
||||
|
||||
def run(code, task):
|
||||
try:
|
||||
code()
|
||||
|
|
|
@ -7,9 +7,11 @@ import git
|
|||
from modules import paths, shared
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
extensions_dir = os.path.join(paths.data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||
|
||||
if not os.path.exists(extensions_dir):
|
||||
os.makedirs(extensions_dir)
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
|
147 modules/extra_networks.py (new file)
|
@ -0,0 +1,147 @@
|
|||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from modules import errors
|
||||
|
||||
extra_network_registry = {}
|
||||
|
||||
|
||||
def initialize():
|
||||
extra_network_registry.clear()
|
||||
|
||||
|
||||
def register_extra_network(extra_network):
|
||||
extra_network_registry[extra_network.name] = extra_network
|
||||
|
||||
|
||||
class ExtraNetworkParams:
|
||||
def __init__(self, items=None):
|
||||
self.items = items or []
|
||||
|
||||
|
||||
class ExtraNetwork:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def activate(self, p, params_list):
|
||||
"""
|
||||
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
|
||||
Passes arguments related to this extra network in params_list.
|
||||
User passes arguments by specifying this in his prompt:
|
||||
|
||||
<name:arg1:arg2:arg3>
|
||||
|
||||
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
|
||||
separated by colon.
|
||||
|
||||
Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -
|
||||
in this case, all effects of this extra networks should be disabled.
|
||||
|
||||
Can be called multiple times before deactivate() - each new call should override the previous call completely.
|
||||
|
||||
For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
|
||||
|
||||
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
|
||||
|
||||
params_list will be:
|
||||
|
||||
[
|
||||
ExtraNetworkParams(items=["agm", "1.1"]),
|
||||
ExtraNetworkParams(items=["ray"])
|
||||
]
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def deactivate(self, p):
|
||||
"""
|
||||
Called at the end of processing for housekeeping. No need to do anything here.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def activate(p, extra_network_data):
|
||||
"""call activate for extra networks in extra_network_data in specified order, then call
|
||||
activate for all remaining registered networks with an empty argument list"""
|
||||
|
||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
if extra_network is None:
|
||||
print(f"Skipping unknown extra network: {extra_network_name}")
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.activate(p, extra_network_args)
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
args = extra_network_data.get(extra_network_name, None)
|
||||
if args is not None:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.activate(p, [])
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network_name}")
|
||||
|
||||
|
||||
def deactivate(p, extra_network_data):
|
||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||
deactivate for all remaining registered networks"""
|
||||
|
||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
if extra_network is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.deactivate(p)
|
||||
except Exception as e:
|
||||
errors.display(e, f"deactivating extra network {extra_network_name}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
args = extra_network_data.get(extra_network_name, None)
|
||||
if args is not None:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.deactivate(p)
|
||||
except Exception as e:
|
||||
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
|
||||
|
||||
|
||||
re_extra_net = re.compile(r"<(\w+):([^>]+)>")


def parse_prompt(prompt):
res = defaultdict(list)

def found(m):
name = m.group(1)
args = m.group(2)

res[name].append(ExtraNetworkParams(items=args.split(":")))

return ""

prompt = re.sub(re_extra_net, found, prompt)

return prompt, res


def parse_prompts(prompts):
res = []
extra_data = None

for prompt in prompts:
updated_prompt, parsed_extra_data = parse_prompt(prompt)

if extra_data is None:
extra_data = parsed_extra_data

res.append(updated_prompt)

return res, extra_data
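For reference, a minimal usage sketch of the prompt parser added above (the import path is an assumption; the expected values follow from the regex and the activate() docstring example):

from modules import extra_networks

prompt, res = extra_networks.parse_prompt("1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>")
# prompt -> "1girl,   "  (the <...> tags are stripped; the separating spaces remain)
# res["hypernet"][0].items -> ["agm", "1.1"]
# res["hypernet"][1].items -> ["ray"]
# res["extrasupernet"][0].items -> ["master", "12", "13", "14"]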
27 modules/extra_networks_hypernet.py (new file)
|
@ -0,0 +1,27 @@
|
|||
from modules import extra_networks, shared
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
|
||||
class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
||||
def __init__(self):
|
||||
super().__init__('hypernet')
|
||||
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_hypernetwork
|
||||
|
||||
if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
|
||||
names = []
|
||||
multipliers = []
|
||||
for params in params_list:
|
||||
assert len(params.items) > 0
|
||||
|
||||
names.append(params.items[0])
|
||||
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
||||
|
||||
hypernetwork.load_hypernetworks(names, multipliers)
|
||||
|
||||
def deactivate(self, p):
|
||||
pass
|
|
@ -1,229 +1,16 @@
|
|||
from __future__ import annotations
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from typing import Callable, List, OrderedDict, Tuple
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
from modules import processing, shared, images, devices, sd_models, sd_samplers
|
||||
from modules.shared import opts
|
||||
import modules.gfpgan_model
|
||||
from modules.ui import plaintext_to_html
|
||||
import modules.codeformer_model
|
||||
from modules import shared, images, sd_models, sd_vae, sd_models_config
|
||||
from modules.ui_common import plaintext_to_html
|
||||
import gradio as gr
|
||||
import safetensors.torch
|
||||
|
||||
class LruCache(OrderedDict):
|
||||
@dataclass(frozen=True)
|
||||
class Key:
|
||||
image_hash: int
|
||||
info_hash: int
|
||||
args_hash: int
|
||||
|
||||
@dataclass
|
||||
class Value:
|
||||
image: Image.Image
|
||||
info: str
|
||||
|
||||
def __init__(self, max_size: int = 5, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._max_size = max_size
|
||||
|
||||
def get(self, key: LruCache.Key) -> LruCache.Value:
|
||||
ret = super().get(key)
|
||||
if ret is not None:
|
||||
self.move_to_end(key) # Move to end of eviction list
|
||||
return ret
|
||||
|
||||
def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
|
||||
self[key] = value
|
||||
while len(self) > self._max_size:
|
||||
self.popitem(last=False)
|
||||
|
||||
|
||||
cached_images: LruCache = LruCache(max_size=5)
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.job = 'extras'
|
||||
|
||||
imageArr = []
|
||||
# Also keep track of original file names
|
||||
imageNameArr = []
|
||||
outputs = []
|
||||
|
||||
if extras_mode == 1:
|
||||
#convert file to pillow image
|
||||
for img in image_folder:
|
||||
image = Image.open(img)
|
||||
imageArr.append(image)
|
||||
imageNameArr.append(os.path.splitext(img.orig_name)[0])
|
||||
elif extras_mode == 2:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||
|
||||
if input_dir == '':
|
||||
return outputs, "Please select an input directory.", ''
|
||||
image_list = shared.listfiles(input_dir)
|
||||
for img in image_list:
|
||||
try:
|
||||
image = Image.open(img)
|
||||
except Exception:
|
||||
continue
|
||||
imageArr.append(image)
|
||||
imageNameArr.append(img)
|
||||
else:
|
||||
imageArr.append(image)
|
||||
imageNameArr.append(None)
|
||||
|
||||
if extras_mode == 2 and output_dir != '':
|
||||
outpath = output_dir
|
||||
else:
|
||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||
|
||||
# Extra operation definitions
|
||||
|
||||
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
shared.state.job = 'extras-gfpgan'
|
||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if gfpgan_visibility < 1.0:
|
||||
res = Image.blend(image, res, gfpgan_visibility)
|
||||
|
||||
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
|
||||
return (res, info)
|
||||
|
||||
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
shared.state.job = 'extras-codeformer'
|
||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if codeformer_visibility < 1.0:
|
||||
res = Image.blend(image, res, codeformer_visibility)
|
||||
|
||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||
return (res, info)
|
||||
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
shared.state.job = 'extras-upscale'
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
if mode == 1 and crop:
|
||||
cropped = Image.new("RGB", (resize_w, resize_h))
|
||||
cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
|
||||
res = cropped
|
||||
return res
|
||||
|
||||
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
|
||||
nonlocal upscaling_resize
|
||||
if resize_mode == 1:
|
||||
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
|
||||
crop_info = " (crop)" if upscaling_crop else ""
|
||||
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
|
||||
return (image, info)
|
||||
|
||||
@dataclass
|
||||
class UpscaleParams:
|
||||
upscaler_idx: int
|
||||
blend_alpha: float
|
||||
|
||||
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
blended_result: Image.Image = None
|
||||
image_hash: str = hash(np.array(image.getdata()).tobytes())
|
||||
for upscaler in params:
|
||||
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
|
||||
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
cache_key = LruCache.Key(image_hash=image_hash,
|
||||
info_hash=hash(info),
|
||||
args_hash=hash(upscale_args))
|
||||
cached_entry = cached_images.get(cache_key)
|
||||
if cached_entry is None:
|
||||
res = upscale(image, *upscale_args)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
|
||||
cached_images.put(cache_key, LruCache.Value(image=res, info=info))
|
||||
else:
|
||||
res, info = cached_entry.image, cached_entry.info
|
||||
|
||||
if blended_result is None:
|
||||
blended_result = res
|
||||
else:
|
||||
blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
|
||||
return (blended_result, info)
|
||||
|
||||
# Build a list of operations to run
|
||||
facefix_ops: List[Callable] = []
|
||||
facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
|
||||
facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
|
||||
|
||||
upscale_ops: List[Callable] = []
|
||||
upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
|
||||
|
||||
if upscaling_resize != 0:
|
||||
step_params: List[UpscaleParams] = []
|
||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
|
||||
|
||||
upscale_ops.append(partial(run_upscalers_blend, step_params))
|
||||
|
||||
extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
|
||||
|
||||
for image, image_name in zip(imageArr, imageNameArr):
|
||||
if image is None:
|
||||
return outputs, "Please select an input image.", ''
|
||||
|
||||
shared.state.textinfo = f'Processing image {image_name}'
|
||||
|
||||
existing_pnginfo = image.info or {}
|
||||
|
||||
image = image.convert("RGB")
|
||||
info = ""
|
||||
# Run each operation on each image
|
||||
for op in extras_ops:
|
||||
image, info = op(image, info)
|
||||
|
||||
if opts.use_original_name_batch and image_name is not None:
|
||||
basename = os.path.splitext(os.path.basename(image_name))[0]
|
||||
else:
|
||||
basename = ''
|
||||
|
||||
if opts.enable_pnginfo: # append info before save
|
||||
image.info = existing_pnginfo
|
||||
image.info["extras"] = info
|
||||
|
||||
if save_output:
|
||||
# Add upscaler name as a suffix.
|
||||
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
|
||||
# Add second upscaler if applicable.
|
||||
if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
|
||||
suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
|
||||
|
||||
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
|
||||
|
||||
if extras_mode != 2 or show_extras_results :
|
||||
outputs.append(image)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
return outputs, plaintext_to_html(info), ''
|
||||
|
||||
def clear_cache():
|
||||
cached_images.clear()
|
||||
|
||||
|
||||
def run_pnginfo(image):
|
||||
if image is None:
|
||||
|
@ -248,10 +35,51 @@ def run_pnginfo(image):
|
|||
return '', geninfo, info
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
||||
def create_config(ckpt_result, config_source, a, b, c):
|
||||
def config(x):
|
||||
res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
|
||||
return res if res != shared.sd_default_config else None
|
||||
|
||||
if config_source == 0:
|
||||
cfg = config(a) or config(b) or config(c)
|
||||
elif config_source == 1:
|
||||
cfg = config(b)
|
||||
elif config_source == 2:
|
||||
cfg = config(c)
|
||||
else:
|
||||
cfg = None
|
||||
|
||||
if cfg is None:
|
||||
return
|
||||
|
||||
filename, _ = os.path.splitext(ckpt_result)
|
||||
checkpoint_filename = filename + ".yaml"
|
||||
|
||||
print("Copying config:")
|
||||
print(" from:", cfg)
|
||||
print(" to:", checkpoint_filename)
|
||||
shutil.copyfile(cfg, checkpoint_filename)
|
||||
|
||||
|
||||
checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||
|
||||
|
||||
def to_half(tensor, enable):
|
||||
if enable and tensor.dtype == torch.float:
|
||||
return tensor.half()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
|
||||
def fail(message):
|
||||
shared.state.textinfo = message
|
||||
shared.state.end()
|
||||
return [*[gr.update() for _ in range(4)], message]
|
||||
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
||||
|
@ -261,91 +89,156 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||
def add_difference(theta0, theta1_2_diff, alpha):
|
||||
return theta0 + (alpha * theta1_2_diff)
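# Toy illustration of the two interpolation modes defined above (not part of the commit),
# assuming get_difference(theta1, theta2) returns theta1 - theta2, with alpha = 0.3:
#   weighted sum:   (1 - 0.3) * 1.0 + 0.3 * 2.0 = 1.3
#   add difference:  1.0 + 0.3 * (2.0 - 0.5)    = 1.45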
|
||||
|
||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
||||
tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
|
||||
result_is_inpainting_model = False
|
||||
def filename_weighted_sum():
|
||||
a = primary_model_info.model_name
|
||||
b = secondary_model_info.model_name
|
||||
Ma = round(1 - multiplier, 2)
|
||||
Mb = round(multiplier, 2)
|
||||
|
||||
return f"{Ma}({a}) + {Mb}({b})"
|
||||
|
||||
def filename_add_difference():
|
||||
a = primary_model_info.model_name
|
||||
b = secondary_model_info.model_name
|
||||
c = tertiary_model_info.model_name
|
||||
M = round(multiplier, 2)
|
||||
|
||||
return f"{a} + {M}({b} - {c})"
|
||||
|
||||
def filename_nothing():
|
||||
return primary_model_info.model_name
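# Example outputs of the filename helpers above (illustration only), for multiplier = 0.3
# and model names A="modelA", B="modelB", C="modelC":
#   filename_weighted_sum()   -> "0.7(modelA) + 0.3(modelB)"
#   filename_add_difference() -> "modelA + 0.3(modelB - modelC)"
#   filename_nothing()        -> "modelA"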
|
||||
|
||||
theta_funcs = {
|
||||
"Weighted sum": (None, weighted_sum),
|
||||
"Add difference": (get_difference, add_difference),
|
||||
"Weighted sum": (filename_weighted_sum, None, weighted_sum),
|
||||
"Add difference": (filename_add_difference, get_difference, add_difference),
|
||||
"No interpolation": (filename_nothing, None, None),
|
||||
}
|
||||
theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
|
||||
|
||||
if theta_func1 and not tertiary_model_info:
|
||||
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
|
||||
shared.state.end()
|
||||
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
if not primary_model_name:
|
||||
return fail("Failed: Merging requires a primary model.")
|
||||
|
||||
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||
|
||||
if theta_func2 and not secondary_model_name:
|
||||
return fail("Failed: Merging requires a secondary model.")
|
||||
|
||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
|
||||
|
||||
if theta_func1 and not tertiary_model_name:
|
||||
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
||||
|
||||
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
||||
|
||||
result_is_inpainting_model = False
|
||||
result_is_instruct_pix2pix_model = False
|
||||
|
||||
if theta_func2:
|
||||
shared.state.textinfo = f"Loading B"
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
else:
|
||||
theta_1 = None
|
||||
|
||||
if theta_func1:
|
||||
shared.state.textinfo = f"Loading C"
|
||||
print(f"Loading {tertiary_model_info.filename}...")
|
||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||
|
||||
shared.state.textinfo = 'Merging B and C'
|
||||
shared.state.sampling_steps = len(theta_1.keys())
|
||||
for key in tqdm.tqdm(theta_1.keys()):
|
||||
if key in checkpoint_dict_skip_on_merge:
|
||||
continue
|
||||
|
||||
if 'model' in key:
|
||||
if key in theta_2:
|
||||
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
||||
theta_1[key] = theta_func1(theta_1[key], t2)
|
||||
else:
|
||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
del theta_2
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||
|
||||
print("Merging...")
|
||||
|
||||
shared.state.textinfo = 'Merging A and B'
|
||||
shared.state.sampling_steps = len(theta_0.keys())
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if 'model' in key and key in theta_1:
|
||||
if theta_1 and 'model' in key and key in theta_1:
|
||||
|
||||
if key in checkpoint_dict_skip_on_merge:
|
||||
continue
|
||||
|
||||
a = theta_0[key]
|
||||
b = theta_1[key]
|
||||
|
||||
shared.state.textinfo = f'Merging layer {key}'
|
||||
# this enables merging an inpainting model (A) with another one (B);
|
||||
# where normal model would have 4 channels, for latenst space, inpainting model would
|
||||
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
||||
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
|
||||
if a.shape[1] == 4 and b.shape[1] == 9:
|
||||
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
|
||||
if a.shape[1] == 4 and b.shape[1] == 8:
|
||||
raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
|
||||
|
||||
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
||||
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||
result_is_inpainting_model = True
|
||||
if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
|
||||
result_is_instruct_pix2pix_model = True
|
||||
else:
|
||||
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||
result_is_inpainting_model = True
|
||||
else:
|
||||
theta_0[key] = theta_func2(a, b, multiplier)
|
||||
|
||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
shared.state.sampling_step += 1
|
||||
|
||||
# I believe this part should be discarded, but I'll leave it for now until I am sure
|
||||
for key in theta_1.keys():
|
||||
if 'model' in key and key not in theta_0:
|
||||
theta_0[key] = theta_1[key]
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
del theta_1
|
||||
|
||||
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
|
||||
if bake_in_vae_filename is not None:
|
||||
print(f"Baking in VAE from {bake_in_vae_filename}")
|
||||
shared.state.textinfo = 'Baking in VAE'
|
||||
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
|
||||
|
||||
for key in vae_dict.keys():
|
||||
theta_0_key = 'first_stage_model.' + key
|
||||
if theta_0_key in theta_0:
|
||||
theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
|
||||
|
||||
del vae_dict
|
||||
|
||||
if save_as_half and not theta_func2:
|
||||
for key in theta_0.keys():
|
||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||
|
||||
if discard_weights:
|
||||
regex = re.compile(discard_weights)
|
||||
for key in list(theta_0):
|
||||
if re.search(regex, key):
|
||||
theta_0.pop(key, None)
|
||||
|
||||
    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

-    filename = \
-        primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
-        secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
-        interp_method.replace(" ", "_") + \
-        '-merged.' + \
-        ("inpainting." if result_is_inpainting_model else "") + \
-        checkpoint_format
-
-    filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
+    filename = filename_generator() if custom_name == '' else custom_name
+    filename += ".inpainting" if result_is_inpainting_model else ""
+    filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
+    filename += "." + checkpoint_format

    output_modelname = os.path.join(ckpt_dir, filename)

shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||
shared.state.nextjob()
|
||||
shared.state.textinfo = "Saving"
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
|
@ -356,8 +249,10 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||
|
||||
    sd_models.list_models()

-    print("Checkpoint saved.")
-    shared.state.textinfo = "Checkpoint saved to " + output_modelname
    create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)

+    print(f"Checkpoint saved to {output_modelname}.")
+    shared.state.textinfo = "Checkpoint saved"
    shared.state.end()

-    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+    return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]

@@ -6,14 +6,13 @@ import re
from pathlib import Path

import gradio as gr
-from modules.shared import script_path
-from modules import shared, ui_tempdir, sd_vae
+from modules.paths import data_path
+from modules import shared, ui_tempdir, script_callbacks, sd_vae
import tempfile
from PIL import Image

-re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
+re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())

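As a quick sanity check of the parameter regex above, a hedged standalone snippet run against a typical infotext line:

    import re

    re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)')
    line = 'Steps: 20, Sampler: Euler a, CFG scale: 7, Size: 512x512'
    print(re_param.findall(line))
    # [('Steps', '20'), ('Sampler', 'Euler a'), ('CFG scale', '7'), ('Size', '512x512')]
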
|
@ -37,6 +36,9 @@ def quote(text):
|
|||
|
||||
|
||||
def image_from_url_text(filedata):
|
||||
if filedata is None:
|
||||
return None
|
||||
|
||||
if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
|
||||
filedata = filedata[0]
|
||||
|
||||
|
@ -76,8 +78,6 @@ def integrate_settings_paste_fields(component_dict):
|
|||
from modules import ui
|
||||
|
||||
settings_map = {
|
||||
'sd_hypernetwork': 'Hypernet',
|
||||
'sd_hypernetwork_strength': 'Hypernet strength',
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'inpainting_mask_weight': 'Conditional mask weight',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
|
@ -198,6 +198,15 @@ def restore_old_hires_fix_params(res):
|
|||
firstpass_width = res.get('First pass size-1', None)
|
||||
firstpass_height = res.get('First pass size-2', None)
|
||||
|
||||
if shared.opts.use_old_hires_fix_width_height:
|
||||
hires_width = int(res.get("Hires resize-1", 0))
|
||||
hires_height = int(res.get("Hires resize-2", 0))
|
||||
|
||||
if hires_width and hires_height:
|
||||
res['Size-1'] = hires_width
|
||||
res['Size-2'] = hires_height
|
||||
return
|
||||
|
||||
if firstpass_width is None or firstpass_height is None:
|
||||
return
|
||||
|
||||
|
@@ -206,12 +215,8 @@ def restore_old_hires_fix_params(res):
    height = int(res.get("Size-2", 512))

    if firstpass_width == 0 or firstpass_height == 0:
-        # old algorithm for auto-calculating first pass size
-        desired_pixel_count = 512 * 512
-        actual_pixel_count = width * height
-        scale = math.sqrt(desired_pixel_count / actual_pixel_count)
-        firstpass_width = math.ceil(scale * width / 64) * 64
-        firstpass_height = math.ceil(scale * height / 64) * 64
+        from modules import processing
+        firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)

    res['Size-1'] = firstpass_width
    res['Size-2'] = firstpass_height

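For reference, the auto-size arithmetic in the removed block (and presumably what processing.old_hires_fix_first_pass_dimensions now encapsulates) works out like this for a hypothetical 1024x768 target:

    import math

    width, height = 1024, 768
    scale = math.sqrt((512 * 512) / (width * height))        # ~0.577: shrink to 512x512 worth of pixels
    firstpass_width = math.ceil(scale * width / 64) * 64      # 640
    firstpass_height = math.ceil(scale * height / 64) * 64    # 448
    print(firstpass_width, firstpass_height)
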
|
@ -238,7 +243,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
done_with_prompt = False
|
||||
|
||||
*lines, lastline = x.strip().split("\n")
|
||||
if not re_params.match(lastline):
|
||||
if len(re_param.findall(lastline)) < 3:
|
||||
lines.append(lastline)
|
||||
lastline = ''
|
||||
|
||||
|
@ -257,6 +262,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
res["Negative prompt"] = negative_prompt
|
||||
|
||||
for k, v in re_param.findall(lastline):
|
||||
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
|
||||
m = re_imagesize.match(v)
|
||||
if m is not None:
|
||||
res[k+"-1"] = m.group(1)
|
||||
|
@ -268,13 +274,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
if "Clip skip" not in res:
|
||||
res["Clip skip"] = "1"
|
||||
|
||||
if "Hypernet strength" not in res:
|
||||
res["Hypernet strength"] = "1"
|
||||
|
||||
if "Hypernet" in res:
|
||||
hypernet_name = res["Hypernet"]
|
||||
hypernet_hash = res.get("Hypernet hash", None)
|
||||
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
|
||||
hypernet = res.get("Hypernet", None)
|
||||
if hypernet is not None:
|
||||
res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
|
||||
|
||||
if "Hires resize-1" not in res:
|
||||
res["Hires resize-1"] = 0
|
||||
|
@ -293,12 +295,13 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
||||
def paste_func(prompt):
|
||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
||||
filename = os.path.join(script_path, "params.txt")
|
||||
filename = os.path.join(data_path, "params.txt")
|
||||
if os.path.exists(filename):
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
prompt = file.read()
|
||||
|
||||
params = parse_generation_parameters(prompt)
|
||||
script_callbacks.infotext_pasted_callback(prompt, params)
|
||||
res = []
|
||||
|
||||
for output, key in paste_fields:
|
||||
|
|
|
@@ -6,12 +6,11 @@ import facexlib
import gfpgan

import modules.face_restoration
-from modules import shared, devices, modelloader
-from modules.paths import models_path
+from modules import paths, shared, devices, modelloader

model_dir = "GFPGAN"
user_path = None
-model_path = os.path.join(models_path, model_dir)
+model_path = os.path.join(paths.models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None

modules/hashes.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import hashlib
import json
import os.path

import filelock

from modules.paths import data_path


cache_filename = os.path.join(data_path, "cache.json")
cache_data = None


def dump_cache():
    with filelock.FileLock(cache_filename+".lock"):
        with open(cache_filename, "w", encoding="utf8") as file:
            json.dump(cache_data, file, indent=4)


def cache(subsection):
    global cache_data

    if cache_data is None:
        with filelock.FileLock(cache_filename+".lock"):
            if not os.path.isfile(cache_filename):
                cache_data = {}
            else:
                with open(cache_filename, "r", encoding="utf8") as file:
                    cache_data = json.load(file)

    s = cache_data.get(subsection, {})
    cache_data[subsection] = s

    return s


def calculate_sha256(filename):
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(blksize), b""):
            hash_sha256.update(chunk)

    return hash_sha256.hexdigest()


def sha256_from_cache(filename, title):
    hashes = cache("hashes")
    ondisk_mtime = os.path.getmtime(filename)

    if title not in hashes:
        return None

    cached_sha256 = hashes[title].get("sha256", None)
    cached_mtime = hashes[title].get("mtime", 0)

    if ondisk_mtime > cached_mtime or cached_sha256 is None:
        return None

    return cached_sha256


def sha256(filename, title):
    hashes = cache("hashes")

    sha256_value = sha256_from_cache(filename, title)
    if sha256_value is not None:
        return sha256_value

    print(f"Calculating sha256 for {filename}: ", end='')
    sha256_value = calculate_sha256(filename)
    print(f"{sha256_value}")

    hashes[title] = {
        "mtime": os.path.getmtime(filename),
        "sha256": sha256_value,
    }

    dump_cache()

    return sha256_value

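A hedged usage sketch of this module (the file path and title below are hypothetical): the first call computes and caches the digest in data_path/cache.json, later calls return the cached value as long as the file's mtime has not advanced.

    from modules import hashes

    digest = hashes.sha256("models/Stable-diffusion/example.safetensors", "checkpoint/example")
    print(digest)        # full sha256 hex digest
    print(digest[:10])   # the short form used elsewhere in the UI
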
@ -12,7 +12,7 @@ import torch
|
|||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
|
||||
from modules.textual_inversion import textual_inversion, logging
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
|
@ -25,7 +25,6 @@ from statistics import stdev, mean
|
|||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
|
@ -38,9 +37,11 @@ class HypernetworkModule(torch.nn.Module):
|
|||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
|
||||
add_layer_norm=False, activate_output=False, dropout_structure=None):
|
||||
super().__init__()
|
||||
|
||||
self.multiplier = 1.0
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
|
@ -63,9 +64,12 @@ class HypernetworkModule(torch.nn.Module):
|
|||
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

-            # Add dropout except last layer
-            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
-                linears.append(torch.nn.Dropout(p=0.3))
+            # Everything should be now parsed into dropout structure, and applied here.
+            # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
+            if dropout_structure is not None and dropout_structure[i+1] > 0:
+                assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
+                linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
+            # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].

        self.linear = torch.nn.Sequential(*linears)

@ -112,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||
state_dict[to] = x
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.linear(x) * self.multiplier
|
||||
return x + self.linear(x) * (self.multiplier if not self.training else 1)
|
||||
|
||||
def trainables(self):
|
||||
layer_structure = []
|
||||
|
@ -122,8 +126,20 @@ class HypernetworkModule(torch.nn.Module):
|
|||
return layer_structure
|
||||
|
||||
|
||||
-def apply_strength(value=None):
-    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
+#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
+def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
+    if layer_structure is None:
+        layer_structure = [1, 2, 1]
+    if not use_dropout:
+        return [0] * len(layer_structure)
+    dropout_values = [0]
+    dropout_values.extend([0.3] * (len(layer_structure) - 3))
+    if last_layer_dropout:
+        dropout_values.append(0.3)
+    else:
+        dropout_values.append(0)
+    dropout_values.append(0)
+    return dropout_values

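For reference, the outputs this helper produces for a few representative inputs (matching the in-code explanation above):

    print(parse_dropout_structure([1, 2, 1], True, False))      # [0, 0, 0]
    print(parse_dropout_structure([1, 2, 1], True, True))       # [0, 0.3, 0]
    print(parse_dropout_structure([1, 2, 2, 1], True, False))   # [0, 0.3, 0, 0]
    print(parse_dropout_structure([1, 2, 2, 1], True, True))    # [0, 0.3, 0.3, 0]
    print(parse_dropout_structure([1, 2, 2, 1], False, True))   # [0, 0, 0, 0]
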
class Hypernetwork:
|
||||
|
@ -143,18 +159,22 @@ class Hypernetwork:
|
|||
self.add_layer_norm = add_layer_norm
|
||||
self.use_dropout = use_dropout
|
||||
self.activate_output = activate_output
|
||||
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
|
||||
self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
|
||||
self.dropout_structure = kwargs.get('dropout_structure', None)
|
||||
if self.dropout_structure is None:
|
||||
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
||||
self.optimizer_name = None
|
||||
self.optimizer_state_dict = None
|
||||
self.optional_info = None
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
||||
)
|
||||
self.eval_mode()
|
||||
self.eval()
|
||||
|
||||
def weights(self):
|
||||
res = []
|
||||
|
@ -163,14 +183,28 @@ class Hypernetwork:
|
|||
res += layer.parameters()
|
||||
return res
|
||||
|
||||
def train_mode(self):
|
||||
def train(self, mode=True):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.train()
|
||||
layer.train(mode=mode)
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = True
|
||||
param.requires_grad = mode
|
||||
|
||||
def eval_mode(self):
|
||||
def to(self, device):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.to(device)
|
||||
|
||||
return self
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.multiplier = multiplier
|
||||
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.eval()
|
||||
|
@ -190,18 +224,20 @@ class Hypernetwork:
|
|||
state_dict['activation_func'] = self.activation_func
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['weight_initialization'] = self.weight_init
|
||||
state_dict['use_dropout'] = self.use_dropout
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
state_dict['activate_output'] = self.activate_output
|
||||
state_dict['last_layer_dropout'] = self.last_layer_dropout
|
||||
state_dict['use_dropout'] = self.use_dropout
|
||||
state_dict['dropout_structure'] = self.dropout_structure
|
||||
state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
|
||||
state_dict['optional_info'] = self.optional_info if self.optional_info else None
|
||||
|
||||
if self.optimizer_name is not None:
|
||||
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
||||
|
||||
torch.save(state_dict, filename)
|
||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
||||
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
|
||||
optimizer_saved_dict['hash'] = self.shorthash()
|
||||
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||
|
||||
|
@ -213,44 +249,65 @@ class Hypernetwork:
|
|||
state_dict = torch.load(filename, map_location='cpu')
|
||||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
print(self.layer_structure)
|
||||
self.optional_info = state_dict.get('optional_info', None)
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
print(f"Activation function is {self.activation_func}")
|
||||
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||
print(f"Weight initialization is {self.weight_init}")
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
print(f"Layer norm is set to {self.add_layer_norm}")
|
||||
self.use_dropout = state_dict.get('use_dropout', False)
|
||||
print(f"Dropout usage is set to {self.use_dropout}" )
|
||||
self.dropout_structure = state_dict.get('dropout_structure', None)
|
||||
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
||||
self.activate_output = state_dict.get('activate_output', True)
|
||||
print(f"Activate last layer is set to {self.activate_output}")
|
||||
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
|
||||
if self.dropout_structure is None:
|
||||
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
||||
|
||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
|
||||
if shared.opts.print_hypernet_extra:
|
||||
if self.optional_info is not None:
|
||||
print(f" INFO:\n {self.optional_info}\n")
|
||||
|
||||
print(f" Layer structure: {self.layer_structure}")
|
||||
print(f" Activation function: {self.activation_func}")
|
||||
print(f" Weight initialization: {self.weight_init}")
|
||||
print(f" Layer norm: {self.add_layer_norm}")
|
||||
print(f" Dropout usage: {self.use_dropout}" )
|
||||
print(f" Activate last layer: {self.activate_output}")
|
||||
print(f" Dropout structure: {self.dropout_structure}")
|
||||
|
||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||
|
||||
if self.shorthash() == optimizer_saved_dict.get('hash', None):
|
||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
else:
|
||||
self.optimizer_state_dict = None
|
||||
if self.optimizer_state_dict:
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||
if shared.opts.print_hypernet_extra:
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
else:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
self.optimizer_name = "AdamW"
|
||||
if shared.opts.print_hypernet_extra:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
self.step = state_dict.get('step', 0)
|
||||
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
||||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||
self.eval()
|
||||
|
||||
def shorthash(self):
|
||||
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
||||
|
||||
return sha256[0:10]
|
||||
|
||||
|
||||
def list_hypernetworks(path):
|
||||
|
@ -259,27 +316,47 @@ def list_hypernetworks(path):
|
|||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
# Prevent a hypothetical "None.pt" from being listed.
|
||||
if name != "None":
|
||||
res[name + f"({sd_models.model_hash(filename)})"] = filename
|
||||
res[name] = filename
|
||||
return res
|
||||
|
||||
|
||||
def load_hypernetwork(filename):
|
||||
path = shared.hypernetworks.get(filename, None)
|
||||
# Prevent any file named "None.pt" from being loaded.
|
||||
if path is not None and filename != "None":
|
||||
print(f"Loading hypernetwork {filename}")
|
||||
try:
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
def load_hypernetwork(name):
|
||||
path = shared.hypernetworks.get(name, None)
|
||||
|
||||
except Exception:
|
||||
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
else:
|
||||
if shared.loaded_hypernetwork is not None:
|
||||
print("Unloading hypernetwork")
|
||||
if path is None:
|
||||
return None
|
||||
|
||||
shared.loaded_hypernetwork = None
|
||||
hypernetwork = Hypernetwork()
|
||||
|
||||
try:
|
||||
hypernetwork.load(path)
|
||||
except Exception:
|
||||
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return None
|
||||
|
||||
return hypernetwork
|
||||
|
||||
|
||||
def load_hypernetworks(names, multipliers=None):
|
||||
already_loaded = {}
|
||||
|
||||
for hypernetwork in shared.loaded_hypernetworks:
|
||||
if hypernetwork.name in names:
|
||||
already_loaded[hypernetwork.name] = hypernetwork
|
||||
|
||||
shared.loaded_hypernetworks.clear()
|
||||
|
||||
for i, name in enumerate(names):
|
||||
hypernetwork = already_loaded.get(name, None)
|
||||
if hypernetwork is None:
|
||||
hypernetwork = load_hypernetwork(name)
|
||||
|
||||
if hypernetwork is None:
|
||||
continue
|
||||
|
||||
hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
|
||||
shared.loaded_hypernetworks.append(hypernetwork)
|
||||
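A hedged sketch of how the multi-hypernetwork loader above is meant to be driven (the names are hypothetical; each multiplier defaults to 1.0 when omitted):

    load_hypernetworks(["anime_style", "detail_boost"], multipliers=[1.0, 0.6])
    print([hn.name for hn in shared.loaded_hypernetworks])   # hypernetworks that actually loaded, in order
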
|
||||
|
||||
def find_closest_hypernetwork_name(search: str):
|
||||
|
@ -293,18 +370,27 @@ def find_closest_hypernetwork_name(search: str):
|
|||
return applicable[0]
|
||||
|
||||
|
||||
def apply_hypernetwork(hypernetwork, context, layer=None):
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
||||
|
||||
if hypernetwork_layers is None:
|
||||
return context, context
|
||||
return context_k, context_v
|
||||
|
||||
if layer is not None:
|
||||
layer.hyper_k = hypernetwork_layers[0]
|
||||
layer.hyper_v = hypernetwork_layers[1]
|
||||
|
||||
context_k = hypernetwork_layers[0](context)
|
||||
context_v = hypernetwork_layers[1](context)
|
||||
context_k = hypernetwork_layers[0](context_k)
|
||||
context_v = hypernetwork_layers[1](context_v)
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
def apply_hypernetworks(hypernetworks, context, layer=None):
|
||||
context_k = context
|
||||
context_v = context
|
||||
for hypernetwork in hypernetworks:
|
||||
context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
|
||||
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
|
@ -314,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
|||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
|
||||
context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
|
||||
|
@ -378,9 +464,10 @@ def report_statistics(loss_info:dict):
|
|||
print(e)
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
assert name, "Name cannot be empty!"
|
||||
|
||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||
if not overwrite_old:
|
||||
|
@ -389,6 +476,11 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||
if type(layer_structure) == str:
|
||||
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||
|
||||
if use_dropout and dropout_structure and type(dropout_structure) == str:
|
||||
dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
|
||||
else:
|
||||
dropout_structure = [0] * len(layer_structure)
|
||||
|
||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
||||
name=name,
|
||||
enable_sizes=[int(x) for x in enable_sizes],
|
||||
|
@ -397,23 +489,27 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||
weight_init=weight_init,
|
||||
add_layer_norm=add_layer_norm,
|
||||
use_dropout=use_dropout,
|
||||
dropout_structure=dropout_structure
|
||||
)
|
||||
hypernet.save(fn)
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||
template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
|
||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||
template_file = template_file.path
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
hypernetwork = Hypernetwork()
|
||||
hypernetwork.load(path)
|
||||
shared.loaded_hypernetworks = [hypernetwork]
|
||||
|
||||
shared.state.job = "train-hypernetwork"
|
||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||
|
@ -437,7 +533,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
else:
|
||||
images_dir = None
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
initial_step = hypernetwork.step or 0
|
||||
|
@ -451,16 +546,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
if clip_grad:
|
||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
|
||||
|
||||
if shared.opts.save_training_settings_to_txt:
|
||||
saved_params = dict(
|
||||
model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
|
||||
model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
|
||||
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
|
||||
)
|
||||
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
|
||||
|
@ -477,7 +575,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
weights = hypernetwork.weights()
|
||||
hypernetwork.train_mode()
|
||||
hypernetwork.train()
|
||||
|
||||
# Here we use optimizer from saved HN, or we can specify as UI option.
|
||||
if hypernetwork.optimizer_name in optimizer_dict:
|
||||
|
@ -506,6 +604,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
_loss_step = 0 #internal
|
||||
# size = len(ds.indexes)
|
||||
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||
loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
|
||||
# losses = torch.zeros((size,))
|
||||
# previous_mean_losses = [0]
|
||||
# previous_mean_loss = 0
|
||||
|
@ -519,6 +618,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
sd_hijack_checkpoint.add()
|
||||
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
break
|
||||
|
@ -555,7 +656,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
# go back until we reach gradient accumulation steps
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
|
||||
loss_logging.append(_loss_step)
|
||||
if clip_grad:
|
||||
clip_grad(weights, clip_grad_sched.learn_rate)
|
||||
|
||||
|
@ -572,7 +673,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
epoch_num = hypernetwork.step // steps_per_epoch
|
||||
epoch_step = hypernetwork.step % steps_per_epoch
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
|
||||
pbar.set_description(description)
|
||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||
|
@ -583,6 +685,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
|
||||
|
||||
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
epoch_num = hypernetwork.step // len(ds)
|
||||
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
||||
mean_loss = sum(loss_logging) / len(loss_logging)
|
||||
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||
"loss": f"{loss_step:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
|
@ -591,7 +701,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
hypernetwork.eval_mode()
|
||||
hypernetwork.eval()
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = None
|
||||
if torch.cuda.is_available():
|
||||
cuda_rng_state = torch.cuda.get_rng_state_all()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
|
@ -601,6 +715,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
do_not_save_samples=True,
|
||||
)
|
||||
|
||||
p.disable_extra_networks = True
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
|
@ -624,9 +740,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
hypernetwork.train_mode()
|
||||
torch.set_rng_state(rng_state)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_rng_state_all(cuda_rng_state)
|
||||
hypernetwork.train()
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
shared.state.assign_current_image(image)
|
||||
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||
textual_inversion.tensorboard_add_image(tensorboard_writer,
|
||||
f"Validation at epoch {epoch_num}", image,
|
||||
hypernetwork.step)
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
|
@ -646,8 +769,11 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
finally:
|
||||
pbar.leave = False
|
||||
pbar.close()
|
||||
hypernetwork.eval_mode()
|
||||
hypernetwork.eval()
|
||||
#report_statistics(loss_dict)
|
||||
sd_hijack_checkpoint.remove()
|
||||
|
||||
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
|
@ -668,7 +794,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
|||
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
||||
try:
|
||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||
hypernetwork.sd_checkpoint = checkpoint.shorthash
|
||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||
hypernetwork.name = hypernetwork_name
|
||||
hypernetwork.save(filename)
|
||||
|
|
|
@@ -9,15 +9,15 @@ from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)

-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
-    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
+    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)

    return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""


def train_hypernetwork(*args):
-    initial_hypernetwork = shared.loaded_hypernetwork
+    shared.loaded_hypernetworks = []

    assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'

@@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)}
    except Exception:
        raise
    finally:
-        shared.loaded_hypernetwork = initial_hypernetwork
        shared.sd_model.cond_stage_model.to(devices.device)
        shared.sd_model.first_stage_model.to(devices.device)
        sd_hijack.apply_optimizations()

@@ -36,6 +36,8 @@ def image_grid(imgs, batch_size=1, rows=None):
        else:
            rows = math.sqrt(len(imgs))
            rows = round(rows)
+            if rows > len(imgs):
+                rows = len(imgs)

    cols = math.ceil(len(imgs) / rows)

@@ -195,7 +197,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
                        ver_texts]

-    pad_top = max(hor_text_heights) + line_spacing * 2
+    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2

    result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
    result.paste(im, (pad_left, pad_top))

@@ -605,8 +607,9 @@ def read_info_from_image(image):
        except ValueError:
            exif_comment = exif_comment.decode('utf8', errors="ignore")

-        items['exif comment'] = exif_comment
-        geninfo = exif_comment
+        if exif_comment:
+            items['exif comment'] = exif_comment
+            geninfo = exif_comment

    for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
                  'loop', 'background', 'timestamp', 'duration']:

@ -16,11 +16,18 @@ import modules.images as images
|
|||
import modules.scripts
|
||||
|
||||
|
||||
def process_batch(p, input_dir, output_dir, args):
|
||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
||||
processing.fix_seed(p)
|
||||
|
||||
images = shared.listfiles(input_dir)
|
||||
|
||||
is_inpaint_batch = False
|
||||
if inpaint_mask_dir:
|
||||
inpaint_masks = shared.listfiles(inpaint_mask_dir)
|
||||
is_inpaint_batch = len(inpaint_masks) > 0
|
||||
if is_inpaint_batch:
|
||||
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
|
||||
|
||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
save_normally = output_dir == ''
|
||||
|
@ -43,6 +50,15 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
img = ImageOps.exif_transpose(img)
|
||||
p.init_images = [img] * p.batch_size
|
||||
|
||||
if is_inpaint_batch:
|
||||
# try to find corresponding mask for an image using simple filename matching
|
||||
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
|
||||
# if not found use first one ("same mask for all images" use-case)
|
||||
if not mask_image_path in inpaint_masks:
|
||||
mask_image_path = inpaint_masks[0]
|
||||
mask_image = Image.open(mask_image_path)
|
||||
p.image_mask = mask_image
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
if proc is None:
|
||||
proc = process_images(p)
|
||||
|
@ -59,38 +75,34 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
is_inpaint = mode == 1
|
||||
is_batch = mode == 2
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args):
|
||||
is_batch = mode == 5
|
||||
|
||||
if is_inpaint:
|
||||
# Drawn mask
|
||||
if mask_mode == 0:
|
||||
is_mask_sketch = isinstance(init_img_with_mask, dict)
|
||||
is_mask_paint = not is_mask_sketch
|
||||
if is_mask_sketch:
|
||||
# Sketch: mask iff. not transparent
|
||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
||||
else:
|
||||
# Color-sketch: mask iff. painted over
|
||||
image = init_img_with_mask
|
||||
orig = init_img_with_mask_orig or init_img_with_mask
|
||||
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||
|
||||
image = image.convert("RGB")
|
||||
# Uploaded mask
|
||||
else:
|
||||
image = init_img_inpaint
|
||||
mask = init_mask_inpaint
|
||||
# No mask
|
||||
if mode == 0: # img2img
|
||||
image = init_img.convert("RGB")
|
||||
mask = None
|
||||
elif mode == 1: # img2img sketch
|
||||
image = sketch.convert("RGB")
|
||||
mask = None
|
||||
elif mode == 2: # inpaint
|
||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
||||
image = image.convert("RGB")
|
||||
elif mode == 3: # inpaint sketch
|
||||
image = inpaint_color_sketch
|
||||
orig = inpaint_color_sketch_orig or inpaint_color_sketch
|
||||
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||
image = image.convert("RGB")
|
||||
elif mode == 4: # inpaint upload mask
|
||||
image = init_img_inpaint
|
||||
mask = init_mask_inpaint
|
||||
else:
|
||||
image = init_img
|
||||
image = None
|
||||
mask = None
|
||||
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
|
@ -105,7 +117,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
styles=[prompt_style, prompt_style2],
|
||||
styles=prompt_styles,
|
||||
seed=seed,
|
||||
subseed=subseed,
|
||||
subseed_strength=subseed_strength,
|
||||
|
@ -143,7 +155,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||
if is_batch:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||
|
||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args)
|
||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
|
||||
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
else:
|
||||
|
|
|
@ -2,15 +2,17 @@ import os
|
|||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torch.hub
|
||||
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import devices, paths, lowvram, modelloader
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
@ -19,30 +21,76 @@ Category = namedtuple("Category", ["name", "topn", "items"])
|
|||
|
||||
re_topn = re.compile(r"\.top(\d+)\.")
|
||||
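For reference, the ".topN." naming convention that re_topn picks up, shown with a hypothetical category file name:

    m = re_topn.search("flavors.top3.txt")
    print(int(m.group(1)) if m else 1)   # -> 3; a plain name such as "artists.txt" falls back to the default of 1
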
|
||||
def category_types():
|
||||
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
|
||||
|
||||
|
||||
def download_default_clip_interrogate_categories(content_dir):
|
||||
print("Downloading CLIP categories...")
|
||||
|
||||
tmpdir = content_dir + "_tmp"
|
||||
category_types = ["artists", "flavors", "mediums", "movements"]
|
||||
|
||||
try:
|
||||
os.makedirs(tmpdir)
|
||||
for category_type in category_types:
|
||||
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
||||
os.rename(tmpdir, content_dir)
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "downloading default CLIP interrogate categories")
|
||||
finally:
|
||||
if os.path.exists(tmpdir):
|
||||
os.remove(tmpdir)
|
||||
|
||||
|
||||
class InterrogateModels:
|
||||
blip_model = None
|
||||
clip_model = None
|
||||
clip_preprocess = None
|
||||
categories = None
|
||||
dtype = None
|
||||
running_on_cpu = None
|
||||
|
||||
def __init__(self, content_dir):
|
||||
self.categories = []
|
||||
self.loaded_categories = None
|
||||
self.skip_categories = []
|
||||
self.content_dir = content_dir
|
||||
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||
|
||||
if os.path.exists(content_dir):
|
||||
for filename in os.listdir(content_dir):
|
||||
m = re_topn.search(filename)
|
||||
topn = 1 if m is None else int(m.group(1))
|
||||
def categories(self):
|
||||
if not os.path.exists(self.content_dir):
|
||||
download_default_clip_interrogate_categories(self.content_dir)
|
||||
|
||||
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
|
||||
if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
|
||||
return self.loaded_categories
|
||||
|
||||
self.loaded_categories = []
|
||||
|
||||
if os.path.exists(self.content_dir):
|
||||
self.skip_categories = shared.opts.interrogate_clip_skip_categories
|
||||
category_types = []
|
||||
for filename in Path(self.content_dir).glob('*.txt'):
|
||||
category_types.append(filename.stem)
|
||||
if filename.stem in self.skip_categories:
|
||||
continue
|
||||
m = re_topn.search(filename.stem)
|
||||
topn = 1 if m is None else int(m.group(1))
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
lines = [x.strip() for x in file.readlines()]
|
||||
|
||||
self.categories.append(Category(name=filename, topn=topn, items=lines))
|
||||
self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
|
||||
|
||||
return self.loaded_categories
|
||||
|
||||
def create_fake_fairscale(self):
|
||||
class FakeFairscale:
|
||||
def checkpoint_wrapper(self):
|
||||
pass
|
||||
|
||||
sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
|
||||
|
||||
def load_blip_model(self):
|
||||
self.create_fake_fairscale()
|
||||
import models.blip
|
||||
|
||||
files = modelloader.load_models(
|
||||
|
@ -106,6 +154,8 @@ class InterrogateModels:
|
|||
def rank(self, image_features, text_array, top_count=1):
|
||||
import clip
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
if shared.opts.interrogate_clip_dict_limit != 0:
|
||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||
|
||||
|
@ -139,7 +189,6 @@ class InterrogateModels:
|
|||
shared.state.begin()
|
||||
shared.state.job = 'interrogate'
|
||||
try:
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
|
@ -159,12 +208,7 @@ class InterrogateModels:
|
|||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
if shared.opts.interrogate_use_builtin_artists:
|
||||
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
|
||||
|
||||
res += ", " + artist[0]
|
||||
|
||||
for name, topn, items in self.categories:
|
||||
for name, topn, items in self.categories():
|
||||
matches = self.rank(image_features, items, top_count=topn)
|
||||
for match, score in matches:
|
||||
if shared.opts.interrogate_return_ranks:
|
||||
|
|
|
@ -10,7 +10,7 @@ from modules.upscaler import Upscaler
|
|||
from modules.paths import script_path, models_path
|
||||
|
||||
|
||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
|
||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
|
||||
"""
|
||||
A one-and-done loader that tries to find the desired models in the specified directories.
|
||||
|
||||
|
@ -45,6 +45,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||
full_path = file
|
||||
if os.path.isdir(full_path):
|
||||
continue
|
||||
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
||||
continue
|
||||
if len(ext_filter) != 0:
|
||||
model_name, extension = os.path.splitext(file)
|
||||
if extension not in ext_filter:
|
||||
|
|
1459
modules/models/diffusion/ddpm_edit.py
Normal file
File diff suppressed because it is too large
|
@ -4,7 +4,15 @@ import sys
|
|||
import modules.safe
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
models_path = os.path.join(script_path, "models")
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
cmd_opts_pre = parser.parse_known_args()[0]
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
models_path = os.path.join(data_path, "models")
|
||||
|
||||
# data_path = cmd_opts_pre.data
|
||||
sys.path.insert(0, script_path)
|
||||
|
||||
# search for directory of stable diffusion in following places
|
||||
|
@ -38,3 +46,17 @@ for d, must_exist, what, options in path_dirs:
|
|||
else:
|
||||
sys.path.append(d)
|
||||
paths[what] = d
|
||||
|
||||
|
||||
class Prioritize:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.path = None
|
||||
|
||||
def __enter__(self):
|
||||
self.path = sys.path.copy()
|
||||
sys.path = [paths[self.name]] + sys.path
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
sys.path = self.path
|
||||
self.path = None
|
||||
|
|
103
modules/postprocessing.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
import os
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.job = 'extras'
|
||||
|
||||
image_data = []
|
||||
image_names = []
|
||||
outputs = []
|
||||
|
||||
if extras_mode == 1:
|
||||
for img in image_folder:
|
||||
image = Image.open(img)
|
||||
image_data.append(image)
|
||||
image_names.append(os.path.splitext(img.orig_name)[0])
|
||||
elif extras_mode == 2:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||
assert input_dir, 'input directory not selected'
|
||||
|
||||
image_list = shared.listfiles(input_dir)
|
||||
for filename in image_list:
|
||||
try:
|
||||
image = Image.open(filename)
|
||||
except Exception:
|
||||
continue
|
||||
image_data.append(image)
|
||||
image_names.append(filename)
|
||||
else:
|
||||
assert image, 'image not selected'
|
||||
|
||||
image_data.append(image)
|
||||
image_names.append(None)
|
||||
|
||||
if extras_mode == 2 and output_dir != '':
|
||||
outpath = output_dir
|
||||
else:
|
||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||
|
||||
infotext = ''
|
||||
|
||||
for image, name in zip(image_data, image_names):
|
||||
shared.state.textinfo = name
|
||||
|
||||
existing_pnginfo = image.info or {}
|
||||
|
||||
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
||||
|
||||
scripts.scripts_postproc.run(pp, args)
|
||||
|
||||
if opts.use_original_name_batch and name is not None:
|
||||
basename = os.path.splitext(os.path.basename(name))[0]
|
||||
else:
|
||||
basename = ''
|
||||
|
||||
infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
|
||||
|
||||
if opts.enable_pnginfo:
|
||||
pp.image.info = existing_pnginfo
|
||||
pp.image.info["postprocessing"] = infotext
|
||||
|
||||
if save_output:
|
||||
images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
|
||||
|
||||
if extras_mode != 2 or show_extras_results:
|
||||
outputs.append(pp.image)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||
"""old handler for API"""
|
||||
|
||||
args = scripts.scripts_postproc.create_args_for_run({
|
||||
"Upscale": {
|
||||
"upscale_mode": resize_mode,
|
||||
"upscale_by": upscaling_resize,
|
||||
"upscale_to_width": upscaling_resize_w,
|
||||
"upscale_to_height": upscaling_resize_h,
|
||||
"upscale_crop": upscaling_crop,
|
||||
"upscaler_1_name": extras_upscaler_1,
|
||||
"upscaler_2_name": extras_upscaler_2,
|
||||
"upscaler_2_visibility": extras_upscaler_2_visibility,
|
||||
},
|
||||
"GFPGAN": {
|
||||
"gfpgan_visibility": gfpgan_visibility,
|
||||
},
|
||||
"CodeFormer": {
|
||||
"codeformer_visibility": codeformer_visibility,
|
||||
"codeformer_weight": codeformer_weight,
|
||||
},
|
||||
})
|
||||
|
||||
return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
|
|
@ -13,10 +13,11 @@ from skimage import exposure
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.paths as paths
|
||||
import modules.face_restoration
|
||||
import modules.images as images
|
||||
import modules.styles
|
||||
|
@ -94,7 +95,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
|
|||
return image_conditioning
|
||||
|
||||
|
||||
class StableDiffusionProcessing():
|
||||
class StableDiffusionProcessing:
|
||||
"""
|
||||
The first set of parameters: sd_models -> do_not_reload_embeddings represents the minimum required to create a StableDiffusionProcessing
|
||||
"""
|
||||
|
@ -102,7 +103,6 @@ class StableDiffusionProcessing():
|
|||
if sampler_index is not None:
|
||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||
|
||||
self.sd_model = sd_model
|
||||
self.outpath_samples: str = outpath_samples
|
||||
self.outpath_grids: str = outpath_grids
|
||||
self.prompt: str = prompt
|
||||
|
@ -141,6 +141,7 @@ class StableDiffusionProcessing():
|
|||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
||||
self.is_using_inpainting_conditioning = False
|
||||
self.disable_extra_networks = False
|
||||
|
||||
if not seed_enable_extras:
|
||||
self.subseed = -1
|
||||
|
@ -156,6 +157,10 @@ class StableDiffusionProcessing():
|
|||
self.all_subseeds = None
|
||||
self.iteration = 0
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
return shared.sd_model
|
||||
|
||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||
|
||||
|
@ -180,7 +185,12 @@ class StableDiffusionProcessing():
|
|||
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
||||
return conditioning
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
|
||||
def edit_image_conditioning(self, source_image):
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
|
||||
|
||||
return conditioning_image
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
# Handle the different mask inputs
|
||||
|
@ -199,7 +209,7 @@ class StableDiffusionProcessing():
|
|||
|
||||
# Create another latent image, this time with a masked version of the original input.
|
||||
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
|
||||
conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
|
||||
conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
|
||||
conditioning_image = torch.lerp(
|
||||
source_image,
|
||||
source_image * (1.0 - conditioning_mask),
|
||||
|
@ -218,11 +228,16 @@ class StableDiffusionProcessing():
|
|||
return image_conditioning
|
||||
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
source_image = devices.cond_cast_float(source_image)
|
||||
|
||||
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
||||
# identify itself with a field common to all models. The conditioning_key is also hybrid.
|
||||
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
||||
return self.depth2img_image_conditioning(source_image)
|
||||
|
||||
if self.sd_model.cond_stage_key == "edit":
|
||||
return self.edit_image_conditioning(source_image)
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||
|
||||
|
@ -236,7 +251,6 @@ class StableDiffusionProcessing():
|
|||
raise NotImplementedError()
|
||||
|
||||
def close(self):
|
||||
self.sd_model = None
|
||||
self.sampler = None
|
||||
|
||||
|
||||
|
@ -436,11 +450,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||
"Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
|
||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
|
@ -449,8 +458,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||
"VAE": "None" if shared.loaded_vae_file is None else os.path.split(shared.loaded_vae_file)[1].removesuffix(".pt"),
|
||||
"VAE hash": None if shared.loaded_vae_file is None else sd_models.model_hash(shared.loaded_vae_file),
|
||||
"VAE": "None" if sd_vae.loaded_vae_file is None else os.path.split(sd_vae.loaded_vae_file)[1].removesuffix(".pt"),
|
||||
"VAE hash": None if sd_vae.loaded_vae_file is None else sd_models.model_hash(sd_vae.loaded_vae_file),
|
||||
}
|
||||
|
||||
generation_params.update(p.extra_generation_params)
|
||||
|
@ -468,15 +477,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
try:
|
||||
for k, v in p.override_settings.items():
|
||||
setattr(opts, k, v)
|
||||
if k == 'sd_hypernetwork':
|
||||
shared.reload_hypernetworks() # make onchange call for changing hypernet
|
||||
|
||||
if k == 'sd_model_checkpoint':
|
||||
sd_models.reload_model_weights() # make onchange call for changing SD model
|
||||
p.sd_model = shared.sd_model
|
||||
sd_models.reload_model_weights()
|
||||
|
||||
if k == 'sd_vae':
|
||||
sd_vae.reload_vae_weights() # make onchange call for changing VAE
|
||||
sd_vae.reload_vae_weights()
|
||||
|
||||
res = process_images_inner(p)
|
||||
|
||||
|
@ -485,9 +491,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
if p.override_settings_restore_afterwards:
|
||||
for k, v in stored_opts.items():
|
||||
setattr(opts, k, v)
|
||||
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
|
||||
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
|
||||
if k == 'sd_vae': sd_vae.reload_vae_weights()
|
||||
if k == 'sd_model_checkpoint':
|
||||
sd_models.reload_model_weights()
|
||||
|
||||
if k == 'sd_vae':
|
||||
sd_vae.reload_vae_weights()
|
||||
|
||||
return res
|
||||
|
||||
|
@ -533,13 +541,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
def infotext(iteration=0, position_in_batch=0):
|
||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
_, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1])
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process(p)
|
||||
|
||||
|
@ -573,6 +579,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
with devices.autocast():
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
|
||||
# for OSX, loading the model during sampling changes the generated picture, so it is loaded here
|
||||
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
|
||||
sd_vae_approx.model()
|
||||
|
||||
if not p.disable_extra_networks:
|
||||
extra_networks.activate(p, extra_network_data)
|
||||
|
||||
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
if state.job_count == -1:
|
||||
state.job_count = p.n_iter
|
||||
|
||||
|
@ -593,6 +610,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if len(prompts) == 0:
|
||||
break
|
||||
|
||||
prompts, _ = extra_networks.parse_prompts(prompts)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
|
||||
|
||||
|
@ -606,10 +625,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
with devices.autocast():
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
||||
|
||||
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
|
||||
for x in x_samples_ddim:
|
||||
devices.test_for_nans(x, "vae")
|
||||
|
||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
|
@ -638,6 +660,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
image = Image.fromarray(x_sample)
|
||||
|
||||
if p.scripts is not None:
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
p.scripts.postprocess_image(p, pp)
|
||||
image = pp.image
|
||||
|
||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
||||
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
|
@ -679,6 +706,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
if not p.disable_extra_networks:
|
||||
extra_networks.deactivate(p, extra_network_data)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
@ -689,6 +719,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
return res
|
||||
|
||||
|
||||
def old_hires_fix_first_pass_dimensions(width, height):
|
||||
"""old algorithm for auto-calculating first pass size"""
|
||||
|
||||
desired_pixel_count = 512 * 512
|
||||
actual_pixel_count = width * height
|
||||
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
|
||||
width = math.ceil(scale * width / 64) * 64
|
||||
height = math.ceil(scale * height / 64) * 64
|
||||
|
||||
return width, height
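# Worked example (illustration, not part of the diff): for a 1024x768 target,
# scale = sqrt(512*512 / (1024*768)) ~= 0.577, so the first pass is rendered at
# ceil(0.577*1024/64)*64 x ceil(0.577*768/64)*64 = 640x448, keeping roughly the
# old 512x512 pixel budget, the aspect ratio, and 64-pixel alignment:
#
#     old_hires_fix_first_pass_dimensions(1024, 768)  # -> (640, 448)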
|
||||
|
||||
|
||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
sampler = None
|
||||
|
||||
|
@ -705,16 +747,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
self.hr_upscale_to_y = hr_resize_y
|
||||
|
||||
if firstphase_width != 0 or firstphase_height != 0:
|
||||
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
|
||||
self.hr_scale = self.width / firstphase_width
|
||||
self.hr_upscale_to_x = self.width
|
||||
self.hr_upscale_to_y = self.height
|
||||
self.width = firstphase_width
|
||||
self.height = firstphase_height
|
||||
|
||||
self.truncate_x = 0
|
||||
self.truncate_y = 0
|
||||
self.applied_old_hires_behavior_to = None
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
if self.enable_hr:
|
||||
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
||||
self.hr_resize_x = self.width
|
||||
self.hr_resize_y = self.height
|
||||
self.hr_upscale_to_x = self.width
|
||||
self.hr_upscale_to_y = self.height
|
||||
|
||||
self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
|
||||
self.applied_old_hires_behavior_to = (self.width, self.height)
|
||||
|
||||
if self.hr_resize_x == 0 and self.hr_resize_y == 0:
|
||||
self.extra_generation_params["Hires upscale"] = self.hr_scale
|
||||
self.hr_upscale_to_x = int(self.width * self.hr_scale)
|
||||
|
@ -833,7 +885,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
|
||||
shared.state.nextjob()
|
||||
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img, so we silently switch to DDIM
|
||||
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
|
||||
|
||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
|
||||
|
||||
|
|
99
modules/progress.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
import base64
|
||||
import io
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
current_task = None
|
||||
pending_tasks = {}
|
||||
finished_tasks = []
|
||||
|
||||
|
||||
def start_task(id_task):
|
||||
global current_task
|
||||
|
||||
current_task = id_task
|
||||
pending_tasks.pop(id_task, None)
|
||||
|
||||
|
||||
def finish_task(id_task):
|
||||
global current_task
|
||||
|
||||
if current_task == id_task:
|
||||
current_task = None
|
||||
|
||||
finished_tasks.append(id_task)
|
||||
if len(finished_tasks) > 16:
|
||||
finished_tasks.pop(0)
|
||||
|
||||
|
||||
def add_task_to_queue(id_job):
|
||||
pending_tasks[id_job] = time.time()
|
||||
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
|
||||
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received live preview image")
|
||||
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
active: bool = Field(title="Whether the task is being worked on right now")
|
||||
queued: bool = Field(title="Whether the task is in queue")
|
||||
completed: bool = Field(title="Whether the task has already finished")
|
||||
progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta: float = Field(default=None, title="ETA in secs")
|
||||
live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
|
||||
id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
|
||||
textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
||||
|
||||
|
||||
def setup_progress_api(app):
|
||||
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
|
||||
|
||||
|
||||
def progressapi(req: ProgressRequest):
|
||||
active = req.id_task == current_task
|
||||
queued = req.id_task in pending_tasks
|
||||
completed = req.id_task in finished_tasks
|
||||
|
||||
if not active:
|
||||
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
|
||||
|
||||
progress = 0
|
||||
|
||||
job_count, job_no = shared.state.job_count, shared.state.job_no
|
||||
sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
|
||||
|
||||
if job_count > 0:
|
||||
progress += job_no / job_count
|
||||
if sampling_steps > 0 and job_count > 0:
|
||||
progress += 1 / job_count * sampling_step / sampling_steps
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
||||
elapsed_since_start = time.time() - shared.state.time_start
|
||||
predicted_duration = elapsed_since_start / progress if progress > 0 else None
|
||||
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
|
||||
|
||||
id_live_preview = req.id_live_preview
|
||||
shared.state.set_current_image()
|
||||
if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
|
||||
image = shared.state.current_image
|
||||
if image is not None:
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="png")
|
||||
live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
|
||||
id_live_preview = shared.state.id_live_preview
|
||||
else:
|
||||
live_preview = None
|
||||
else:
|
||||
live_preview = None
|
||||
|
||||
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
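# Client-side sketch (illustration only, not part of this diff): how a caller might
# poll the endpoint added above. It assumes the webui listens on 127.0.0.1:7860,
# that the `requests` package is available, and that "task-123" is a task id the
# client previously submitted.
#
#     import requests, time
#
#     id_live_preview = -1
#     while True:
#         r = requests.post("http://127.0.0.1:7860/internal/progress",
#                           json={"id_task": "task-123", "id_live_preview": id_live_preview}).json()
#         id_live_preview = r["id_live_preview"]
#         print(r["progress"], r["textinfo"])
#         if r["completed"] or not (r["active"] or r["queued"]):
#             break
#         time.sleep(1)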
|
||||
|
|
@ -49,6 +49,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||
[[5, 'a c'], [10, 'a {b|d{ c']]
|
||||
>>> g("((a][:b:c [d:3]")
|
||||
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||
>>> g("[a|(b:1.1)]")
|
||||
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
||||
"""
|
||||
|
||||
def collect_steps(steps, tree):
|
||||
|
@ -84,7 +86,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||
yield args[0].value
|
||||
def __default__(self, data, children, meta):
|
||||
for child in children:
|
||||
yield from child
|
||||
yield child
|
||||
return AtStep().transform(tree)
|
||||
|
||||
def get_schedule(prompt):
|
||||
|
@ -272,6 +274,7 @@ re_attention = re.compile(r"""
|
|||
:
|
||||
""", re.X)
|
||||
|
||||
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
|
||||
|
||||
def parse_prompt_attention(text):
|
||||
"""
|
||||
|
@ -337,7 +340,11 @@ def parse_prompt_attention(text):
|
|||
elif text == ']' and len(square_brackets) > 0:
|
||||
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||
else:
|
||||
res.append([text, 1.0])
|
||||
parts = re.split(re_break, text)
|
||||
for i, part in enumerate(parts):
|
||||
if i > 0:
|
||||
res.append(["BREAK", -1])
|
||||
res.append([part, 1.0])
|
||||
|
||||
for pos in round_brackets:
|
||||
multiply_range(pos, round_bracket_multiplier)
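# Illustration (not part of the diff): the re_break split above turns a prompt such
# as "a photo of a cat BREAK oil painting" into text segments separated by a
# ["BREAK", -1] marker, which the CLIP hijack later uses to start a new chunk.
# A standalone sketch of the same splitting idea:
#
#     import re
#     re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
#     re.split(re_break, "a photo of a cat BREAK oil painting")
#     # -> ['a photo of a cat', 'oil painting']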
|
||||
|
|
|
@ -38,15 +38,15 @@ class UpscalerRealESRGAN(Upscaler):
|
|||
return img
|
||||
|
||||
info = self.load_model(path)
|
||||
if not os.path.exists(info.data_path):
|
||||
if not os.path.exists(info.local_data_path):
|
||||
print("Unable to load RealESRGAN model: %s" % info.name)
|
||||
return img
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.scale,
|
||||
model_path=info.data_path,
|
||||
model_path=info.local_data_path,
|
||||
model=info.model(),
|
||||
half=not cmd_opts.no_half,
|
||||
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
||||
tile=opts.ESRGAN_tile,
|
||||
tile_pad=opts.ESRGAN_tile_overlap,
|
||||
)
|
||||
|
@ -58,17 +58,13 @@ class UpscalerRealESRGAN(Upscaler):
|
|||
|
||||
def load_model(self, path):
|
||||
try:
|
||||
info = None
|
||||
for scaler in self.scalers:
|
||||
if scaler.data_path == path:
|
||||
info = scaler
|
||||
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
|
||||
|
||||
if info is None:
|
||||
print(f"Unable to find model info: {path}")
|
||||
return None
|
||||
|
||||
model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||
info.data_path = model_file
|
||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||
return info
|
||||
except Exception as e:
|
||||
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
||||
|
|
|
@ -2,7 +2,7 @@ import sys
|
|||
import traceback
|
||||
from collections import namedtuple
|
||||
import inspect
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
|
@ -71,7 +71,9 @@ callback_map = dict(
|
|||
callbacks_before_component=[],
|
||||
callbacks_after_component=[],
|
||||
callbacks_image_grid=[],
|
||||
callbacks_infotext_pasted=[],
|
||||
callbacks_script_unloaded=[],
|
||||
callbacks_before_ui=[],
|
||||
)
|
||||
|
||||
|
||||
|
@ -172,6 +174,14 @@ def image_grid_callback(params: ImageGridLoopParams):
|
|||
report_exception(c, 'image_grid')
|
||||
|
||||
|
||||
def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
|
||||
for c in callback_map['callbacks_infotext_pasted']:
|
||||
try:
|
||||
c.callback(infotext, params)
|
||||
except Exception:
|
||||
report_exception(c, 'infotext_pasted')
|
||||
|
||||
|
||||
def script_unloaded_callback():
|
||||
for c in reversed(callback_map['callbacks_script_unloaded']):
|
||||
try:
|
||||
|
@ -180,6 +190,14 @@ def script_unloaded_callback():
|
|||
report_exception(c, 'script_unloaded')
|
||||
|
||||
|
||||
def before_ui_callback():
|
||||
for c in reversed(callback_map['callbacks_before_ui']):
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'before_ui')
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
|
@ -290,8 +308,23 @@ def on_image_grid(callback):
|
|||
add_callback(callback_map['callbacks_image_grid'], callback)
|
||||
|
||||
|
||||
def on_infotext_pasted(callback):
|
||||
"""register a function to be called before applying an infotext.
|
||||
The callback is called with two arguments:
|
||||
- infotext: str - raw infotext.
|
||||
- params: Dict[str, Any] - parsed infotext parameters.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_infotext_pasted'], callback)
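# Example sketch (illustration only, not part of this diff): an extension could use
# this hook to migrate a renamed infotext parameter when users paste old generation
# data. The parameter names below are hypothetical.
#
#     from modules import script_callbacks
#
#     def restore_old_param(infotext, params):
#         if "Old hypernet" in params and "Hypernet" not in params:
#             params["Hypernet"] = params.pop("Old hypernet")
#
#     script_callbacks.on_infotext_pasted(restore_old_param)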
|
||||
|
||||
|
||||
def on_script_unloaded(callback):
|
||||
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
|
||||
the script did should be reverted here"""
|
||||
|
||||
add_callback(callback_map['callbacks_script_unloaded'], callback)
|
||||
|
||||
|
||||
def on_before_ui(callback):
|
||||
"""register a function to be called before the UI is created."""
|
||||
|
||||
add_callback(callback_map['callbacks_before_ui'], callback)
|
||||
|
|
|
@ -1,16 +1,14 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import importlib.util
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def load_module(path):
|
||||
with open(path, "r", encoding="utf8") as file:
|
||||
text = file.read()
|
||||
|
||||
compiled = compile(text, path, 'exec')
|
||||
module = ModuleType(os.path.basename(path))
|
||||
exec(compiled, module.__dict__)
|
||||
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
module_spec.loader.exec_module(module)
|
||||
|
||||
return module
|
||||
|
||||
|
|
|
@ -6,12 +6,16 @@ from collections import namedtuple
|
|||
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules import shared, paths, script_callbacks, extensions, script_loading
|
||||
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
|
||||
class PostprocessImageArgs:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
|
||||
|
||||
class Script:
|
||||
filename = None
|
||||
args_from = None
|
||||
|
@ -65,7 +69,7 @@ class Script:
|
|||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
|
||||
def process(self, p, *args):
|
||||
"""
|
||||
|
@ -100,6 +104,13 @@ class Script:
|
|||
|
||||
pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
|
||||
pass
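# Illustration (not part of this diff): a minimal always-on script using the new
# postprocess_image hook might look roughly like this; the class name and the
# add_watermark helper are hypothetical.
#
#     import modules.scripts as scripts
#
#     class WatermarkScript(scripts.Script):
#         def title(self):
#             return "Watermark (example)"
#
#         def show(self, is_img2img):
#             return scripts.AlwaysVisible
#
#         def postprocess_image(self, p, pp, *args):
#             pp.image = add_watermark(pp.image)  # hypothetical helper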
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
|
@ -150,9 +161,11 @@ def basedir():
|
|||
return current_basedir
|
||||
|
||||
|
||||
scripts_data = []
|
||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
|
||||
|
||||
scripts_data = []
|
||||
postprocessing_scripts_data = []
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension):
|
||||
|
@ -190,23 +203,31 @@ def list_files_with_name(filename):
|
|||
def load_scripts():
|
||||
global current_basedir
|
||||
scripts_data.clear()
|
||||
postprocessing_scripts_data.clear()
|
||||
script_callbacks.clear_callbacks()
|
||||
|
||||
scripts_list = list_scripts("scripts", ".py")
|
||||
|
||||
syspath = sys.path
|
||||
|
||||
def register_scripts_from_module(module):
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) != type:
|
||||
continue
|
||||
|
||||
if issubclass(script_class, Script):
|
||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||
|
||||
for scriptfile in sorted(scripts_list):
|
||||
try:
|
||||
if scriptfile.basedir != paths.script_path:
|
||||
sys.path = [scriptfile.basedir] + sys.path
|
||||
current_basedir = scriptfile.basedir
|
||||
|
||||
module = script_loading.load_module(scriptfile.path)
|
||||
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
|
||||
script_module = script_loading.load_module(scriptfile.path)
|
||||
register_scripts_from_module(script_module)
|
||||
|
||||
except Exception:
|
||||
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
||||
|
@ -237,11 +258,15 @@ class ScriptRunner:
|
|||
self.infotext_fields = []
|
||||
|
||||
def initialize_scripts(self, is_img2img):
|
||||
from modules import scripts_auto_postprocessing
|
||||
|
||||
self.scripts.clear()
|
||||
self.alwayson_scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
|
||||
for script_class, path, basedir in scripts_data:
|
||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||
|
||||
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
|
||||
script = script_class()
|
||||
script.filename = path
|
||||
script.is_txt2img = not is_img2img
|
||||
|
@ -320,9 +345,23 @@ class ScriptRunner:
|
|||
outputs=[script.group for script in self.selectable_scripts]
|
||||
)
|
||||
|
||||
self.script_load_ctr = 0
|
||||
def onload_script_visibility(params):
|
||||
title = params.get('Script', None)
|
||||
if title:
|
||||
title_index = self.titles.index(title)
|
||||
visibility = title_index == self.script_load_ctr
|
||||
self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
|
||||
return gr.update(visible=visibility)
|
||||
else:
|
||||
return gr.update(visible=False)
|
||||
|
||||
self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
|
||||
self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
|
||||
|
||||
return inputs
|
||||
|
||||
def run(self, p: StableDiffusionProcessing, *args):
|
||||
def run(self, p, *args):
|
||||
script_index = args[0]
|
||||
|
||||
if script_index == 0:
|
||||
|
@ -376,6 +415,15 @@ class ScriptRunner:
|
|||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess_image(p, pp, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
for script in self.scripts:
|
||||
try:
|
||||
|
@ -413,6 +461,7 @@ class ScriptRunner:
|
|||
|
||||
scripts_txt2img = ScriptRunner()
|
||||
scripts_img2img = ScriptRunner()
|
||||
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||
scripts_current: ScriptRunner = None
|
||||
|
||||
|
||||
|
@ -423,12 +472,13 @@ def reload_script_body_only():
|
|||
|
||||
|
||||
def reload_scripts():
|
||||
global scripts_txt2img, scripts_img2img
|
||||
global scripts_txt2img, scripts_img2img, scripts_postproc
|
||||
|
||||
load_scripts()
|
||||
|
||||
scripts_txt2img = ScriptRunner()
|
||||
scripts_img2img = ScriptRunner()
|
||||
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
|
|
42
modules/scripts_auto_postprocessing.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
from modules import scripts, scripts_postprocessing, shared
|
||||
|
||||
|
||||
class ScriptPostprocessingForMainUI(scripts.Script):
|
||||
def __init__(self, script_postproc):
|
||||
self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
|
||||
self.postprocessing_controls = None
|
||||
|
||||
def title(self):
|
||||
return self.script.name
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
self.postprocessing_controls = self.script.ui()
|
||||
return self.postprocessing_controls.values()
|
||||
|
||||
def postprocess_image(self, p, script_pp, *args):
|
||||
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
|
||||
|
||||
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
||||
pp.info = {}
|
||||
self.script.process(pp, **args_dict)
|
||||
p.extra_generation_params.update(pp.info)
|
||||
script_pp.image = pp.image
|
||||
|
||||
|
||||
def create_auto_preprocessing_script_data():
|
||||
from modules import scripts
|
||||
|
||||
res = []
|
||||
|
||||
for name in shared.opts.postprocessing_enable_in_main_ui:
|
||||
script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
|
||||
if script is None:
|
||||
continue
|
||||
|
||||
constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
|
||||
res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
|
||||
|
||||
return res
|
152
modules/scripts_postprocessing.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import errors, shared
|
||||
|
||||
|
||||
class PostprocessedImage:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
self.info = {}
|
||||
|
||||
|
||||
class ScriptPostprocessing:
|
||||
filename = None
|
||||
controls = None
|
||||
args_from = None
|
||||
args_to = None
|
||||
|
||||
order = 1000
|
||||
"""scripts will be ordred by this value in postprocessing UI"""
|
||||
|
||||
name = None
|
||||
"""this function should return the title of the script."""
|
||||
|
||||
group = None
|
||||
"""A gr.Group component that has all script's UI inside it"""
|
||||
|
||||
def ui(self):
|
||||
"""
|
||||
This function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||
The return value should be a dictionary that maps parameter names to components used in processing.
|
||||
Values of those components will be passed to the process() function.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def process(self, pp: PostprocessedImage, **args):
|
||||
"""
|
||||
This function is called to postprocess the image.
|
||||
args contains a dictionary with all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def image_changed(self):
|
||||
pass
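# Illustration (not part of this diff): a minimal postprocessing script for the
# extras tab might look roughly like this; the names are hypothetical and the
# gradio control follows the ui() contract described above.
#
#     import gradio as gr
#     from modules import scripts_postprocessing
#
#     class ScriptPostprocessingGrayscale(scripts_postprocessing.ScriptPostprocessing):
#         name = "Grayscale (example)"
#         order = 9000
#
#         def ui(self):
#             enable = gr.Checkbox(False, label="Convert to grayscale")
#             return {"enable": enable}
#
#         def process(self, pp, enable):
#             if enable:
#                 pp.image = pp.image.convert("L").convert("RGB")
#                 pp.info["Grayscale"] = True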
|
||||
|
||||
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
return res
|
||||
except Exception as e:
|
||||
errors.display(e, f"calling {filename}/{funcname}")
|
||||
|
||||
return default
|
||||
|
||||
|
||||
class ScriptPostprocessingRunner:
|
||||
def __init__(self):
|
||||
self.scripts = None
|
||||
self.ui_created = False
|
||||
|
||||
def initialize_scripts(self, scripts_data):
|
||||
self.scripts = []
|
||||
|
||||
for script_class, path, basedir, script_module in scripts_data:
|
||||
script: ScriptPostprocessing = script_class()
|
||||
script.filename = path
|
||||
|
||||
if script.name == "Simple Upscale":
|
||||
continue
|
||||
|
||||
self.scripts.append(script)
|
||||
|
||||
def create_script_ui(self, script, inputs):
|
||||
script.args_from = len(inputs)
|
||||
script.args_to = len(inputs)
|
||||
|
||||
script.controls = wrap_call(script.ui, script.filename, "ui")
|
||||
|
||||
for control in script.controls.values():
|
||||
control.custom_script_source = os.path.basename(script.filename)
|
||||
|
||||
inputs += list(script.controls.values())
|
||||
script.args_to = len(inputs)
|
||||
|
||||
def scripts_in_preferred_order(self):
|
||||
if self.scripts is None:
|
||||
import modules.scripts
|
||||
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
|
||||
|
||||
scripts_order = shared.opts.postprocessing_operation_order
|
||||
|
||||
def script_score(name):
|
||||
for i, possible_match in enumerate(scripts_order):
|
||||
if possible_match == name:
|
||||
return i
|
||||
|
||||
return len(self.scripts)
|
||||
|
||||
script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
|
||||
|
||||
return sorted(self.scripts, key=lambda x: script_scores[x.name])
|
||||
|
||||
def setup_ui(self):
|
||||
inputs = []
|
||||
|
||||
for script in self.scripts_in_preferred_order():
|
||||
with gr.Box() as group:
|
||||
self.create_script_ui(script, inputs)
|
||||
|
||||
script.group = group
|
||||
|
||||
self.ui_created = True
|
||||
return inputs
|
||||
|
||||
def run(self, pp: PostprocessedImage, args):
|
||||
for script in self.scripts_in_preferred_order():
|
||||
shared.state.job = script.name
|
||||
|
||||
script_args = args[script.args_from:script.args_to]
|
||||
|
||||
process_args = {}
|
||||
for (name, component), value in zip(script.controls.items(), script_args):
|
||||
process_args[name] = value
|
||||
|
||||
script.process(pp, **process_args)
|
||||
|
||||
def create_args_for_run(self, scripts_args):
|
||||
if not self.ui_created:
|
||||
with gr.Blocks(analytics_enabled=False):
|
||||
self.setup_ui()
|
||||
|
||||
scripts = self.scripts_in_preferred_order()
|
||||
args = [None] * max([x.args_to for x in scripts])
|
||||
|
||||
for script in scripts:
|
||||
script_args_dict = scripts_args.get(script.name, None)
|
||||
if script_args_dict is not None:
|
||||
|
||||
for i, name in enumerate(script.controls):
|
||||
args[script.args_from + i] = script_args_dict.get(name, None)
|
||||
|
||||
return args
|
||||
|
||||
def image_changed(self):
|
||||
for script in self.scripts_in_preferred_order():
|
||||
script.image_changed()
|
||||
|
90
modules/sd_disable_initialization.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
import ldm.modules.encoders.modules
|
||||
import open_clip
|
||||
import torch
|
||||
import transformers.utils.hub
|
||||
|
||||
|
||||
class DisableInitialization:
|
||||
"""
|
||||
When an object of this class enters a `with` block, it starts:
|
||||
- preventing torch's layer initialization functions from working
|
||||
- changes CLIP and OpenCLIP to not download model weights
|
||||
- changes CLIP to not make requests to check if there is a new version of a file you already have
|
||||
|
||||
When it leaves the block, it reverts everything to how it was before.
|
||||
|
||||
Use it like this:
|
||||
```
|
||||
with DisableInitialization():
|
||||
do_things()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.replaced = []
|
||||
|
||||
def replace(self, obj, field, func):
|
||||
original = getattr(obj, field, None)
|
||||
if original is None:
|
||||
return None
|
||||
|
||||
self.replaced.append((obj, field, original))
|
||||
setattr(obj, field, func)
|
||||
|
||||
return original
|
||||
|
||||
def __enter__(self):
|
||||
def do_nothing(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
|
||||
return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
|
||||
|
||||
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
|
||||
res.name_or_path = pretrained_model_name_or_path
|
||||
return res
|
||||
|
||||
def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
|
||||
args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
|
||||
return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
|
||||
|
||||
def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
|
||||
|
||||
# this file is always 404, prevent making request
|
||||
if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
|
||||
return None
|
||||
|
||||
try:
|
||||
res = original(url, *args, local_files_only=True, **kwargs)
|
||||
if res is None:
|
||||
res = original(url, *args, local_files_only=False, **kwargs)
|
||||
return res
|
||||
except Exception as e:
|
||||
return original(url, *args, local_files_only=False, **kwargs)
|
||||
|
||||
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
||||
return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
|
||||
|
||||
def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
|
||||
|
||||
def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
|
||||
|
||||
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for obj, field, original in self.replaced:
|
||||
setattr(obj, field, original)
|
||||
|
||||
self.replaced.clear()
|
||||
|
|
@ -70,9 +70,10 @@ def undo_optimizations():
|
|||
|
||||
|
||||
def fix_checkpoint():
|
||||
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
|
||||
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
||||
checkpoints to be added when not training (there's a warning)"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StableDiffusionModelHijack:
|
||||
|
@ -106,8 +107,6 @@ class StableDiffusionModelHijack:
|
|||
self.optimization_method = apply_optimizations()
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
fix_checkpoint()
|
||||
|
||||
def flatten(el):
|
||||
flattened = [flatten(children) for children in el.children()]
|
||||
|
@ -132,6 +131,8 @@ class StableDiffusionModelHijack:
|
|||
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
undo_optimizations()
|
||||
|
||||
self.apply_circular(False)
|
||||
self.layers = None
|
||||
self.clip = None
|
||||
|
@ -172,7 +173,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
|||
vecs = []
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = embedding.vec
|
||||
emb = devices.cond_cast_unet(embedding.vec)
|
||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||
|
||||
|
|
|
@ -1,10 +1,46 @@
|
|||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
|
||||
|
||||
def BasicTransformerBlock_forward(self, x, context=None):
|
||||
return checkpoint(self._forward, x, context)
|
||||
|
||||
|
||||
def AttentionBlock_forward(self, x):
|
||||
return checkpoint(self._forward, x)
|
||||
|
||||
|
||||
def ResBlock_forward(self, x, emb):
|
||||
return checkpoint(self._forward, x, emb)
|
||||
return checkpoint(self._forward, x, emb)
|
||||
|
||||
|
||||
stored = []
|
||||
|
||||
|
||||
def add():
|
||||
if len(stored) != 0:
|
||||
return
|
||||
|
||||
stored.extend([
|
||||
ldm.modules.attention.BasicTransformerBlock.forward,
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
|
||||
])
|
||||
|
||||
ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
|
||||
|
||||
|
||||
def remove():
|
||||
if len(stored) == 0:
|
||||
return
|
||||
|
||||
ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
|
||||
|
||||
stored.clear()
|
||||
|
||||
|
|
|
@ -96,13 +96,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||
token_count = 0
|
||||
last_comma = -1
|
||||
|
||||
def next_chunk():
|
||||
"""puts current chunk into the list of results and produces the next one - empty"""
|
||||
def next_chunk(is_last=False):
|
||||
"""puts current chunk into the list of results and produces the next one - empty;
|
||||
if is_last is true, the <end-of-text> padding tokens at the end won't be added to token_count"""
|
||||
nonlocal token_count
|
||||
nonlocal last_comma
|
||||
nonlocal chunk
|
||||
|
||||
token_count += len(chunk.tokens)
|
||||
if is_last:
|
||||
token_count += len(chunk.tokens)
|
||||
else:
|
||||
token_count += self.chunk_length
|
||||
|
||||
to_add = self.chunk_length - len(chunk.tokens)
|
||||
if to_add > 0:
|
||||
chunk.tokens += [self.id_end] * to_add
|
||||
|
@ -116,6 +121,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||
chunk = PromptChunk()
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
if text == 'BREAK' and weight == -1:
|
||||
next_chunk()
|
||||
continue
|
||||
|
||||
position = 0
|
||||
while position < len(tokens):
|
||||
token = tokens[position]
|
||||
|
@ -159,7 +168,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||
position += embedding_length_in_tokens
|
||||
|
||||
if len(chunk.tokens) > 0 or len(chunks) == 0:
|
||||
next_chunk()
|
||||
next_chunk(is_last=True)
|
||||
|
||||
return chunks, token_count
|
||||
|
||||
|
|
|
@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||
return x_prev, pred_x0, e_t
|
||||
|
||||
|
||||
def should_hijack_inpainting(checkpoint_info):
|
||||
from modules import sd_models
|
||||
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
|
||||
|
||||
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
|
||||
|
|
13
modules/sd_hijack_ip2p.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import time
|
||||
|
||||
def should_hijack_ip2p(checkpoint_info):
|
||||
from modules import sd_models_config
|
||||
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
||||
|
||||
return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
|
|
@ -9,7 +9,7 @@ from torch import einsum
|
|||
from ldm.util import default
|
||||
from einops import rearrange
|
||||
|
||||
from modules import shared
|
||||
from modules import shared, errors, devices
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
for i in range(0, q.shape[0], 2):
|
||||
end = i + 2
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], 2):
|
||||
end = i + 2
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
del q, k, v
|
||||
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
del q, k, v
|
||||
r1 = r1.to(dtype)
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
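The pattern added throughout this file is: remember the incoming dtype, optionally upcast q/k/v to float32, run the attention math with autocast disabled, then cast the result back so the rest of the float16 UNet is unaffected. A condensed sketch of that pattern, assuming the webui's modules package is importable (this is not the repository's function, just the shape of the change):

import torch
from modules import devices, shared

def attention_upcast_sketch(q, k, v, scale):
    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()
    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        attn = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)  # done in fp32 when upcasting
        out = attn @ v
    return out.to(dtype)                                          # back to the UNet's dtype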
@ -78,49 +85,56 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
k_in *= self.scale
|
||||
dtype = q_in.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
|
||||
|
||||
del context, x
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
k_in = k_in * self.scale
|
||||
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
mem_free_total = get_available_vram()
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
||||
del q, k, v
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
mem_free_total = get_available_vram()
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
||||
del q, k, v
|
||||
r1 = r1.to(dtype)
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
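The slicing heuristic itself is unchanged by this hunk, only moved under without_autocast: when the estimated attention tensor does not fit in free VRAM, the query dimension is split into 2^ceil(log2(required/free)) slices. A worked example with illustrative numbers:

import math

mem_free_total = 4 * 1024**3              # pretend 4 GB of VRAM is free
tensor_size = 18 * 1024**3                # pretend the full attention matrix needs 18 GB
mem_required = tensor_size * 2.5          # fp32 modifier from the diff

steps = 1
if mem_required > mem_free_total:
    steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))

print(steps)                              # 16 -> the query dimension is processed in 16 slices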
@ -203,13 +217,21 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
k = self.to_k(context_k) * self.scale
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = einsum_op(q, k, v)
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
|
||||
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
k = k * self.scale
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = einsum_op(q, k, v)
|
||||
r = r.to(dtype)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
|
||||
# -- End of code from https://github.com/invoke-ai/InvokeAI --
|
||||
|
@ -225,7 +247,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
|
||||
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
||||
|
||||
x = x.to(dtype)
|
||||
|
||||
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
|
||||
|
||||
out_proj, dropout = self.to_out
|
||||
|
@ -268,15 +296,31 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|||
query_chunk_size = q_tokens
|
||||
kv_chunk_size = k_tokens
|
||||
|
||||
return efficient_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_chunk_size=q_chunk_size,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
kv_chunk_size_min = kv_chunk_size_min,
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
with devices.without_autocast(disable=q.dtype == v.dtype):
|
||||
return efficient_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_chunk_size=q_chunk_size,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
kv_chunk_size_min = kv_chunk_size_min,
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def get_xformers_flash_attention_op(q, k, v):
|
||||
if not shared.cmd_opts.xformers_flash_attention:
|
||||
return None
|
||||
|
||||
try:
|
||||
flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
|
||||
fw, bw = flash_attention_op
|
||||
if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
|
||||
return flash_attention_op
|
||||
except Exception as e:
|
||||
errors.display_once(e, "enabling flash attention")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
|
@ -284,13 +328,20 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
|
||||
|
||||
out = out.to(dtype)
|
||||
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
@ -362,10 +413,14 @@ def xformers_attnblock_forward(self, x):
|
|||
v = self.v(h_)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
|
||||
out = out.to(dtype)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
||||
out = self.proj_out(out)
|
||||
return x + out
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
import torch
|
||||
from packaging import version
|
||||
|
||||
from modules import devices
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
|
||||
|
||||
class TorchHijackForUnet:
|
||||
|
@ -28,3 +32,37 @@ class TorchHijackForUnet:
|
|||
|
||||
|
||||
th = TorchHijackForUnet()
|
||||
|
||||
|
||||
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||
|
||||
if isinstance(cond, dict):
|
||||
for y in cond.keys():
|
||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
|
||||
with devices.autocast():
|
||||
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
||||
|
||||
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||
def forward(self, x):
|
||||
if devices.unet_needs_upcast:
|
||||
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
||||
else:
|
||||
return torch.nn.GELU.forward(self, x)
|
||||
|
||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.1"):
|
||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
|
||||
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||
|
|
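The patches above keep the UNet weights in float16 but run numerically sensitive ops (GELU, GroupNorm, timestep embedding) in float32 and cast back. A tiny standalone illustration of that cast-up/cast-down pattern (not the repository's code):

import torch

x = torch.randn(4, dtype=torch.float16)
y = torch.nn.functional.gelu(x.float()).to(torch.float16)   # compute in fp32, store in fp16
print(y.dtype)                                              # torch.float16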
28
modules/sd_hijack_utils.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import importlib
|
||||
|
||||
class CondFunc:
|
||||
def __new__(cls, orig_func, sub_func, cond_func):
|
||||
self = super(CondFunc, cls).__new__(cls)
|
||||
if isinstance(orig_func, str):
|
||||
func_path = orig_func.split('.')
|
||||
for i in range(len(func_path)-1, -1, -1):
|
||||
try:
|
||||
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
||||
break
|
||||
except ImportError:
|
||||
pass
|
||||
for attr_name in func_path[i:-1]:
|
||||
resolved_obj = getattr(resolved_obj, attr_name)
|
||||
orig_func = getattr(resolved_obj, func_path[-1])
|
||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
|
||||
self.__init__(orig_func, sub_func, cond_func)
|
||||
return lambda *args, **kwargs: self(*args, **kwargs)
|
||||
def __init__(self, orig_func, sub_func, cond_func):
|
||||
self.__orig_func = orig_func
|
||||
self.__sub_func = sub_func
|
||||
self.__cond_func = cond_func
|
||||
def __call__(self, *args, **kwargs):
|
||||
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
|
||||
return self.__sub_func(self.__orig_func, *args, **kwargs)
|
||||
else:
|
||||
return self.__orig_func(*args, **kwargs)
|
|
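CondFunc resolves a dotted path (or accepts a callable directly), installs a wrapper, and routes each call to sub_func(orig_func, ...) when cond_func(orig_func, ...) is truthy, falling back to the original otherwise. A hypothetical usage sketch, wrapping a plain callable so nothing gets monkey-patched:

from modules.sd_hijack_utils import CondFunc   # the new module added above

def orig(x):
    return x + 1

# the substitute runs only when the condition holds
wrapped = CondFunc(orig,
                   lambda orig_func, x: orig_func(x) * 10,   # sub_func
                   lambda orig_func, x: x > 0)               # cond_func

print(wrapped(2))    # 30  (condition true -> substitute)
print(wrapped(-2))   # -1  (condition false -> original)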
@ -2,7 +2,6 @@ import collections
|
|||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
|
@ -13,17 +12,64 @@ import ldm.modules.midas as midas
|
|||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||
from modules.paths import models_path
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules.timer import Timer
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||
|
||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||
checkpoints_list = {}
|
||||
checkpoint_alisases = {}
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
|
||||
class CheckpointInfo:
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
abspath = os.path.abspath(filename)
|
||||
|
||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
name = abspath.replace(model_path, '')
|
||||
else:
|
||||
name = os.path.basename(filename)
|
||||
|
||||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
self.name = name
|
||||
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
self.hash = model_hash(filename)
|
||||
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
|
||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||
|
||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||
|
||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||
|
||||
def register(self):
|
||||
checkpoints_list[self.title] = self
|
||||
for id in self.ids:
|
||||
checkpoint_alisases[id] = self
|
||||
|
||||
def calculate_shorthash(self):
|
||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
||||
self.shorthash = self.sha256[0:10]
|
||||
|
||||
if self.shorthash not in self.ids:
|
||||
self.ids += [self.shorthash, self.sha256]
|
||||
self.register()
|
||||
|
||||
self.title = f'{self.name} [{self.shorthash}]'
|
||||
|
||||
return self.shorthash
|
||||
|
||||
|
||||
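The rewritten CheckpointInfo collects every alias a checkpoint can be referred to by (old hash, model name, title, sha256 short hash), so lookups go through checkpoint_alisases instead of a title scan. A small sketch of how the title is assembled once the short hash is known (dummy values, not a real checkpoint):

name = "v1-5-pruned-emaonly.safetensors"
sha256 = "0123456789abcdef" * 4            # placeholder; a real value comes from hashes.sha256()
shorthash = sha256[0:10]

title = f'{name} [{shorthash}]'
print(title)                               # v1-5-pruned-emaonly.safetensors [0123456789]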
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
|
||||
|
@ -42,64 +88,50 @@ def setup_model():
|
|||
enable_midas_autodownload()
|
||||
|
||||
|
||||
def checkpoint_tiles():
|
||||
convert = lambda name: int(name) if name.isdigit() else name.lower()
|
||||
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
|
||||
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
||||
def checkpoint_tiles():
|
||||
def convert(name):
|
||||
return int(name) if name.isdigit() else name.lower()
|
||||
|
||||
def alphanumeric_key(key):
|
||||
return [convert(c) for c in re.split('([0-9]+)', key)]
|
||||
|
||||
def find_checkpoint_config(info):
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return shared.cmd_opts.config
|
||||
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
||||
|
||||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
|
||||
|
||||
def modeltitle(path, shorthash):
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
name = abspath.replace(model_path, '')
|
||||
else:
|
||||
name = os.path.basename(path)
|
||||
|
||||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
|
||||
return f'{name} [{shorthash}]', shortname
|
||||
checkpoint_alisases.clear()
|
||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
|
||||
|
||||
cmd_ckpt = shared.cmd_opts.ckpt
|
||||
if os.path.exists(cmd_ckpt):
|
||||
h = model_hash(cmd_ckpt)
|
||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
||||
shared.opts.data['sd_model_checkpoint'] = title
|
||||
checkpoint_info = CheckpointInfo(cmd_ckpt)
|
||||
checkpoint_info.register()
|
||||
|
||||
shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||
|
||||
for filename in model_list:
|
||||
h = model_hash(filename)
|
||||
title, short_model_name = modeltitle(filename, h)
|
||||
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
||||
checkpoint_info = CheckpointInfo(filename)
|
||||
checkpoint_info.register()
|
||||
|
||||
|
||||
def get_closet_checkpoint_match(searchString):
|
||||
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
|
||||
if len(applicable) > 0:
|
||||
return applicable[0]
|
||||
def get_closet_checkpoint_match(search_string):
|
||||
checkpoint_info = checkpoint_alisases.get(search_string, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
||||
found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
|
||||
if found:
|
||||
return found[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def model_hash(filename):
|
||||
"""old hash that only looks at a small part of the file and is prone to collisions"""
|
||||
|
||||
try:
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
|
@ -115,7 +147,7 @@ def model_hash(filename):
|
|||
def select_checkpoint():
|
||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||
|
||||
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
|
||||
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
||||
|
@ -171,9 +203,7 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
device = map_location or shared.weight_load_location
|
||||
if device is None:
|
||||
device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
|
||||
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||
|
@ -185,61 +215,85 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
|||
return sd
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
sd_model_hash = checkpoint_info.hash
|
||||
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||
|
||||
if cache_enabled and checkpoint_info in checkpoints_loaded:
|
||||
if checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||
else:
|
||||
# load from file
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||
return checkpoints_loaded[checkpoint_info]
|
||||
|
||||
sd = read_state_dict(checkpoint_file)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
del sd
|
||||
|
||||
if cache_enabled:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
res = read_state_dict(checkpoint_info.filename)
|
||||
timer.record("load weights from disk")
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
return res
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
if state_dict is None:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
timer.record("apply weights to model")
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
timer.record("apply channels_last")
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
depth_model = getattr(model, 'depth_model', None)
|
||||
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
# with --upcast-sampling, don't convert the depth model weights to float16
|
||||
if shared.cmd_opts.upcast_sampling and depth_model:
|
||||
model.depth_model = None
|
||||
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
if depth_model:
|
||||
model.depth_model = depth_model
|
||||
|
||||
timer.record("apply half()")
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
devices.dtype_unet = model.model.diffusion_model.dtype
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
timer.record("apply dtype to VAE")
|
||||
|
||||
# clean up cache if limit is reached
|
||||
if cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False)
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpoint = checkpoint_file
|
||||
model.sd_model_checkpoint = checkpoint_info.filename
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
||||
|
||||
model.logvar = model.logvar.to(devices.device) # fix for training
|
||||
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
sd_vae.load_vae(model, vae_file)
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
||||
sd_vae.load_vae(model, vae_file, vae_source)
|
||||
timer.record("load VAE")
|
||||
|
||||
|
||||
def enable_midas_autodownload():
|
||||
|
@ -252,7 +306,7 @@ def enable_midas_autodownload():
|
|||
location automatically.
|
||||
"""
|
||||
|
||||
midas_path = os.path.join(models_path, 'midas')
|
||||
midas_path = os.path.join(paths.models_path, 'midas')
|
||||
|
||||
# stable-diffusion-stability-ai hard-codes the midas model path to
|
||||
# a location that differs from where other scripts using this model look.
|
||||
|
@ -285,13 +339,20 @@ def enable_midas_autodownload():
|
|||
midas.api.load_model = load_model_wrapper
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None):
|
||||
def repair_config(sd_config):
|
||||
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
elif shared.cmd_opts.upcast_sampling:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
checkpoint_config = find_checkpoint_config(checkpoint_info)
|
||||
|
||||
if checkpoint_config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_config}")
|
||||
|
||||
if shared.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
|
@ -299,42 +360,66 @@ def load_model(checkpoint_info=None):
|
|||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
|
||||
if should_hijack_inpainting(checkpoint_info):
|
||||
# Hardcoded config for now...
|
||||
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
sd_config.model.params.conditioning_key = "hybrid"
|
||||
sd_config.model.params.unet_config.params.in_channels = 9
|
||||
sd_config.model.params.finetune_keys = None
|
||||
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
do_inpainting_hijack()
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
timer = Timer()
|
||||
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
if already_loaded_state_dict is not None:
|
||||
state_dict = already_loaded_state_dict
|
||||
else:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
repair_config(sd_config)
|
||||
|
||||
timer.record("load config")
|
||||
|
||||
print(f"Creating model from config: {checkpoint_config}")
|
||||
|
||||
sd_model = None
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if sd_model is None:
|
||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
sd_model.used_config = checkpoint_config
|
||||
|
||||
timer.record("create model")
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||
else:
|
||||
sd_model.to(shared.device)
|
||||
|
||||
timer.record("move model to device")
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
timer.record("hijack")
|
||||
|
||||
sd_model.eval()
|
||||
shared.sd_model = sd_model
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||
|
||||
timer.record("load textual inversion embeddings")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
print("Model loaded.")
|
||||
timer.record("scripts callbacks")
|
||||
|
||||
print(f"Model loaded in {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
||||
|
@ -346,38 +431,51 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
|
||||
if sd_model is None: # previous model load failed
|
||||
current_checkpoint_info = None
|
||||
else:
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
timer = Timer()
|
||||
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info)
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
|
||||
return shared.sd_model
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
try:
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
except Exception as e:
|
||||
print("Failed to load checkpoint, restoring previous")
|
||||
load_model_weights(sd_model, current_checkpoint_info)
|
||||
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
timer.record("hijack")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
timer.record("script callbacks")
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
timer.record("move model to device")
|
||||
|
||||
print("Weights loaded.")
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
|
112
modules/sd_models_config.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import re
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared, paths, sd_disable_initialization
|
||||
|
||||
sd_configs_path = shared.sd_configs_path
|
||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
|
||||
|
||||
config_default = shared.sd_default_config
|
||||
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
|
||||
|
||||
def is_using_v_parameterization_for_sd2(state_dict):
|
||||
"""
|
||||
Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
||||
"""
|
||||
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
from modules import devices
|
||||
|
||||
device = devices.cpu
|
||||
|
||||
with sd_disable_initialization.DisableInitialization():
|
||||
unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
||||
use_checkpoint=True,
|
||||
use_fp16=False,
|
||||
image_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
model_channels=320,
|
||||
attention_resolutions=[4, 2, 1],
|
||||
num_res_blocks=2,
|
||||
channel_mult=[1, 2, 4, 4],
|
||||
num_head_channels=64,
|
||||
use_spatial_transformer=True,
|
||||
use_linear_in_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=1024,
|
||||
legacy=False
|
||||
)
|
||||
unet.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
||||
unet.load_state_dict(unet_sd, strict=True)
|
||||
unet.to(device=device, dtype=torch.float)
|
||||
|
||||
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
||||
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
||||
|
||||
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
|
||||
|
||||
return out < -1
|
||||
|
||||
|
||||
def guess_model_config_from_state_dict(sd, filename):
|
||||
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
|
||||
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
return config_depth_model
|
||||
|
||||
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_sd2_inpainting
|
||||
elif is_using_v_parameterization_for_sd2(sd):
|
||||
return config_sd2v
|
||||
else:
|
||||
return config_sd2
|
||||
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_inpainting
|
||||
if diffusion_model_input.shape[1] == 8:
|
||||
return config_instruct_pix2pix
|
||||
|
||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||
return config_alt_diffusion
|
||||
|
||||
return config_default
|
||||
|
||||
|
||||
def find_checkpoint_config(state_dict, info):
|
||||
if info is None:
|
||||
return guess_model_config_from_state_dict(state_dict, "")
|
||||
|
||||
config = find_checkpoint_config_near_filename(info)
|
||||
if config is not None:
|
||||
return config
|
||||
|
||||
return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||
|
||||
|
||||
def find_checkpoint_config_near_filename(info):
|
||||
if info is None:
|
||||
return None
|
||||
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return None
|
||||
|
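find_checkpoint_config first looks for a .yaml next to the checkpoint and otherwise guesses from tensor shapes in the state dict: a 1024-wide cond projection means SD2 (with v-parameterization detected by probing the UNet), 9 UNet input channels means inpainting, 8 means instruct-pix2pix, and a RoBERTa embedding means Alt-Diffusion. A toy shape check (fabricated state dict, not a real checkpoint):

import torch

fake_sd = {
    'model.diffusion_model.input_blocks.0.0.weight': torch.zeros(320, 9, 3, 3),
}

diffusion_model_input = fake_sd.get('model.diffusion_model.input_blocks.0.0.weight')
if diffusion_model_input is not None and diffusion_model_input.shape[1] == 9:
    print("9 input channels -> inpainting config")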
|
@ -138,9 +138,9 @@ def samples_to_image_grid(samples, approximation=None):
|
|||
def store_latent(decoded):
|
||||
state.current_latent = decoded
|
||||
|
||||
if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if not shared.parallel_processing_allowed:
|
||||
shared.state.current_image = sample_to_image(decoded)
|
||||
shared.state.assign_current_image(sample_to_image(decoded))
|
||||
|
||||
|
||||
class InterruptedException(BaseException):
|
||||
|
@ -243,7 +243,7 @@ class VanillaStableDiffusionSampler:
|
|||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
@ -266,8 +266,7 @@ class VanillaStableDiffusionSampler:
|
|||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
@ -352,6 +351,13 @@ class CFGDenoiser(torch.nn.Module):
|
|||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
store_latent(x_out[0:uncond.shape[0]])
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
|
@ -423,7 +429,8 @@ class KDiffusionSampler:
|
|||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
store_latent(latent)
|
||||
if opts.live_preview_content == "Combined":
|
||||
store_latent(latent)
|
||||
self.last_latent = latent
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
|
@ -447,7 +454,7 @@ class KDiffusionSampler:
|
|||
def initialize(self, p):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap.step = 0
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.eta = p.eta or opts.eta_ancestral
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
|
|
@ -1,37 +1,25 @@
|
|||
import torch
|
||||
import safetensors.torch
|
||||
import os
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from modules import shared, devices, script_callbacks, sd_models
|
||||
from modules.paths import models_path
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
vae_dir = "VAE"
|
||||
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
|
||||
|
||||
|
||||
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
|
||||
|
||||
default_vae_dict = {"auto": "auto", "None": None, None: None}
|
||||
default_vae_list = ["auto", "None"]
|
||||
|
||||
|
||||
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
|
||||
vae_dict = dict(default_vae_dict)
|
||||
vae_list = list(default_vae_list)
|
||||
vae_hash_to_filename = defaultdict(list)
|
||||
first_load = True
|
||||
vae_dict = {}
|
||||
|
||||
|
||||
base_vae = None
|
||||
loaded_vae_file = None
|
||||
checkpoint_info = None
|
||||
|
||||
vae_hash_to_filename = defaultdict(list)
|
||||
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
def get_base_vae(model):
|
||||
|
@ -43,7 +31,7 @@ def get_base_vae(model):
|
|||
def store_base_vae(model):
|
||||
global base_vae, checkpoint_info
|
||||
if checkpoint_info != model.sd_checkpoint_info:
|
||||
assert not shared.loaded_vae_file, "Trying to store non-base VAE!"
|
||||
assert not loaded_vae_file, "Trying to store non-base VAE!"
|
||||
base_vae = deepcopy(model.first_stage_model.state_dict())
|
||||
checkpoint_info = model.sd_checkpoint_info
|
||||
|
||||
|
@ -55,103 +43,95 @@ def delete_base_vae():
|
|||
|
||||
|
||||
def restore_base_vae(model):
|
||||
global loaded_vae_file
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
||||
print("Restoring base VAE")
|
||||
_load_vae_dict(model, base_vae)
|
||||
shared.loaded_vae_file = None
|
||||
shared.opts.sd_vae = "None"
|
||||
loaded_vae_file = None
|
||||
delete_base_vae()
|
||||
|
||||
|
||||
def get_filename(filepath):
|
||||
return os.path.splitext(os.path.basename(filepath))[0]
|
||||
return os.path.basename(filepath)
|
||||
|
||||
|
||||
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
||||
global vae_dict, vae_list, vae_hash_to_filename
|
||||
res = {}
|
||||
candidates = [
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
|
||||
def refresh_vae_list():
|
||||
vae_dict.clear()
|
||||
|
||||
paths = [
|
||||
os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
|
||||
os.path.join(sd_models.model_path, '**/*.vae.pt'),
|
||||
os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
|
||||
os.path.join(vae_path, '**/*.ckpt'),
|
||||
os.path.join(vae_path, '**/*.pt'),
|
||||
os.path.join(vae_path, '**/*.safetensors'),
|
||||
]
|
||||
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
|
||||
candidates.append(shared.cmd_opts.vae_path)
|
||||
|
||||
if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
|
||||
paths += [
|
||||
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
|
||||
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
|
||||
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
|
||||
]
|
||||
|
||||
if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
|
||||
paths += [
|
||||
os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
|
||||
os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
|
||||
os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
|
||||
]
|
||||
|
||||
candidates = []
|
||||
for path in paths:
|
||||
candidates += glob.iglob(path, recursive=True)
|
||||
|
||||
vae_hash_to_filename.clear()
|
||||
for filepath in candidates:
|
||||
name = get_filename(filepath)
|
||||
res[name] = filepath
|
||||
vae_dict[name] = filepath
|
||||
vae_hash_to_filename[sd_models.model_hash(filepath)].append(name)
|
||||
vae_list.clear()
|
||||
vae_list.extend(default_vae_list)
|
||||
vae_list.extend(list(res.keys()))
|
||||
vae_dict.clear()
|
||||
vae_dict.update(res)
|
||||
vae_dict.update(default_vae_dict)
|
||||
return vae_list
|
||||
|
||||
|
||||
def get_vae_from_settings(vae_file="auto"):
|
||||
# else, we load from settings, if not set to be default
|
||||
if vae_file == "auto" and shared.opts.sd_vae is not None:
|
||||
# if saved VAE settings isn't recognized, fallback to auto
|
||||
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
|
||||
# if VAE selected but not found, fallback to auto
|
||||
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
|
||||
vae_file = "auto"
|
||||
print(f"Selected VAE doesn't exist: {vae_file}")
|
||||
return vae_file
|
||||
def find_vae_near_checkpoint(checkpoint_file):
|
||||
checkpoint_path = os.path.splitext(checkpoint_file)[0]
|
||||
for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
|
||||
if os.path.isfile(vae_location):
|
||||
return vae_location
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def resolve_vae(checkpoint_file=None, vae_file="auto"):
|
||||
global first_load, vae_dict, vae_list
|
||||
def resolve_vae(checkpoint_file):
|
||||
if shared.cmd_opts.vae_path is not None:
|
||||
return shared.cmd_opts.vae_path, 'from commandline argument'
|
||||
|
||||
# if vae_file argument is provided, it takes priority, but not saved
|
||||
if vae_file and vae_file not in default_vae_list:
|
||||
if not os.path.isfile(vae_file):
|
||||
print(f"VAE provided as function argument doesn't exist: {vae_file}")
|
||||
vae_file = "auto"
|
||||
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
|
||||
if first_load and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
shared.opts.data['sd_vae'] = get_filename(vae_file)
|
||||
else:
|
||||
print(f"VAE provided as command line argument doesn't exist: {vae_file}")
|
||||
# fallback to selector in settings, if vae selector not set to act as default fallback
|
||||
if not shared.opts.sd_vae_as_default:
|
||||
vae_file = get_vae_from_settings(vae_file)
|
||||
# vae-path cmd arg takes priority for auto
|
||||
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
print(f"Using VAE provided as command line argument: {vae_file}")
|
||||
# if still not found, try look for ".vae.pt" beside model
|
||||
model_path = os.path.splitext(checkpoint_file)[0]
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.pt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# if still not found, try look for ".vae.ckpt" beside model
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.ckpt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# No more fallbacks for auto
|
||||
if vae_file == "auto":
|
||||
vae_file = None
|
||||
# Last check, just because
|
||||
if vae_file and not os.path.exists(vae_file):
|
||||
vae_file = None
|
||||
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||
|
||||
return vae_file
|
||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
|
||||
return vae_near_checkpoint, 'found near the checkpoint'
|
||||
|
||||
if shared.opts.sd_vae == "None":
|
||||
return None, None
|
||||
|
||||
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
|
||||
if vae_from_options is not None:
|
||||
return vae_from_options, 'specified in settings'
|
||||
|
||||
if not is_automatic:
|
||||
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
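If I read the new resolve_vae correctly, the priority order is now: the --vae-path command line argument, then a .vae.pt/.vae.ckpt/.vae.safetensors sitting next to the checkpoint (when the setting is Automatic or sd_vae_as_default is on), then the VAE chosen in settings, and finally no external VAE at all. A hedged usage sketch (hypothetical path, requires the webui environment):

# returns a (path, source-description) pair, e.g. (path, 'found near the checkpoint')
vae_file, vae_source = resolve_vae("/models/Stable-diffusion/example-model.safetensors")
print(vae_source)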
def load_vae(model, vae_file=None):
|
||||
global first_load, vae_dict, vae_list
|
||||
def load_vae_dict(filename, map_location):
|
||||
vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
return vae_dict_1
|
||||
|
||||
|
||||
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||
global vae_dict, loaded_vae_file
|
||||
# save_settings = False
|
||||
|
||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||
|
@ -159,15 +139,15 @@ def load_vae(model, vae_file=None):
|
|||
if vae_file:
|
||||
if cache_enabled and vae_file in checkpoints_loaded:
|
||||
# use vae checkpoint cache
|
||||
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
|
||||
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
|
||||
store_base_vae(model)
|
||||
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||
else:
|
||||
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights {vae_source}: {vae_file}")
|
||||
store_base_vae(model)
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
|
||||
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
|
||||
_load_vae_dict(model, vae_dict_1)
|
||||
|
||||
if cache_enabled:
|
||||
|
@ -184,12 +164,12 @@ def load_vae(model, vae_file=None):
|
|||
vae_opt = get_filename(vae_file)
|
||||
if vae_opt not in vae_dict:
|
||||
vae_dict[vae_opt] = vae_file
|
||||
vae_list.append(vae_opt)
|
||||
shared.loaded_vae_file = vae_file
|
||||
elif shared.loaded_vae_file:
|
||||
vae_hash_to_filename[sd_models.model_hash(vae_file)] = vae_opt
|
||||
|
||||
elif loaded_vae_file:
|
||||
restore_base_vae(model)
|
||||
|
||||
first_load = False
|
||||
loaded_vae_file = vae_file
|
||||
|
||||
|
||||
# don't call this from outside
|
||||
|
@ -197,10 +177,16 @@ def _load_vae_dict(model, vae_dict_1):
|
|||
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
def clear_loaded_vae():
|
||||
shared.loaded_vae_file = None
|
||||
|
||||
def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||
def clear_loaded_vae():
|
||||
global loaded_vae_file
|
||||
loaded_vae_file = None
|
||||
|
||||
|
||||
unspecified = object()
|
||||
|
||||
|
||||
def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
|
||||
if not sd_model:
|
||||
|
@ -208,9 +194,13 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
|||
|
||||
checkpoint_info = sd_model.sd_checkpoint_info
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
|
||||
if shared.loaded_vae_file == vae_file:
|
||||
if vae_file == unspecified:
|
||||
vae_file, vae_source = resolve_vae(checkpoint_file)
|
||||
else:
|
||||
vae_source = "from function argument"
|
||||
|
||||
if loaded_vae_file == vae_file:
|
||||
return
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
|
@ -220,7 +210,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
|||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_vae(sd_model, vae_file)
|
||||
load_vae(sd_model, vae_file, vae_source)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
@ -228,12 +218,16 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
|
|||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
print("VAE Weights loaded.")
|
||||
print("VAE weights loaded.")
|
||||
return sd_model
|
||||
|
||||
|
||||
def is_valid_vae(vae_file: str):
|
||||
return vae_file in vae_dict
|
||||
"""
|
||||
Returns true if the vae_file name exists in the cache of vae files
|
||||
A vae_file of "None" is valid because it represents the "None" option in the vae_se
|
||||
"""
|
||||
return vae_file in vae_dict or vae_file == "None"
|
||||
|
||||
|
||||
def find_vae_key(vae_name, vae_hash=None):
|
||||
|
@ -241,11 +235,13 @@ def find_vae_key(vae_name, vae_hash=None):
|
|||
If vae_hash is provided, this function will return the name of any local VAE file that matches the hash.
|
||||
If vae_hash is None, this function will return vae_name if any local VAE files are named vae_name
|
||||
"""
|
||||
if vae_name == "None":
|
||||
return vae_name
|
||||
if vae_hash is not None and (matched := vae_hash_to_filename.get(vae_hash)):
|
||||
if vae_name in matched or vae_name.lower() in matched:
|
||||
return vae_name
|
||||
return matched[0]
|
||||
else:
|
||||
if vae_name.lower() in [vae_filename.lower() for vae_filename in vae_list]:
|
||||
if vae_name.lower() in [vae_filename.lower() for vae_filename in vae_dict.keys()]:
|
||||
return vae_name
|
||||
return None
|
||||
|
|
|
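is_valid_vae and find_vae_key are helpers specific to this branch. A rough usage sketch (the file name and hash are made up; the real call sites live in the generation-parameter handling):

vae_key = find_vae_key("orangemix.vae.pt")                           # match by name if it is in vae_dict
if vae_key is None:
    vae_key = find_vae_key("orangemix.vae.pt", vae_hash="1a2b3c4d")  # fall back to hash lookup
if vae_key is not None and is_valid_vae(vae_key):
    print(f"resolved VAE option: {vae_key}")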
@ -36,7 +36,7 @@ def model():

if sd_vae_approx_model is None:
sd_vae_approx_model = VAEApprox()
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
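The added map_location argument makes torch.load remap GPU-saved tensors to the CPU when no CUDA device is active, which avoids load failures on CPU-only and MPS machines. A standalone sketch of the same idea (the file name is illustrative):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Without remapping, loading a checkpoint saved from a GPU fails on machines without CUDA.
state = torch.load("model.pt", map_location="cpu" if device.type != "cuda" else None)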
@ -9,30 +9,35 @@ from PIL import Image
|
|||
import gradio as gr
|
||||
import tqdm
|
||||
|
||||
import modules.artists
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, sd_vae, extensions, script_loading, errors
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
|
||||
from modules.paths import models_path, script_path, data_path
|
||||
|
||||
|
||||
demo = None
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
|
@ -42,6 +47,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
|
|||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
|
@ -54,6 +60,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
|
|||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
|
@ -63,26 +70,27 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage
|
|||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
|
@ -96,6 +104,8 @@ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS o
|
|||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
|
||||
|
||||
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
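Several defaults above move from script_path to the new data_path so user data (embeddings, ui-config.json, config.json, styles.csv) can live outside the installation directory via --data-dir. A rough sketch, under assumptions, of how such a data_path default can be derived before the main parser runs (this is not the actual modules/paths.py):

import argparse
import os

pre = argparse.ArgumentParser(add_help=False)
pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
known, _ = pre.parse_known_args()

data_path = known.data_dir
ui_config_default = os.path.join(data_path, "ui-config.json")   # mirrors the new defaults above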
|
||||
|
@ -115,6 +125,7 @@ restricted_opts = {
|
|||
}
|
||||
|
||||
ui_reorder_categories = [
|
||||
"inpaint",
|
||||
"sampler",
|
||||
"dimensions",
|
||||
"cfg",
|
||||
|
@ -140,9 +151,7 @@ config_filename = cmd_opts.ui_settings_file
|
|||
|
||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||
hypernetworks = {}
|
||||
loaded_hypernetwork = None
|
||||
|
||||
loaded_vae_file = None
|
||||
loaded_hypernetworks = []
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
|
@ -150,7 +159,6 @@ def reload_hypernetworks():
|
|||
global hypernetworks
|
||||
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
|
||||
|
||||
|
||||
class State:
|
||||
|
@ -166,9 +174,11 @@ class State:
|
|||
current_latent = None
|
||||
current_image = None
|
||||
current_image_sampling_step = 0
|
||||
id_live_preview = 0
|
||||
textinfo = None
|
||||
time_start = None
|
||||
need_restart = False
|
||||
server_start = None
|
||||
|
||||
def skip(self):
|
||||
self.skipped = True
|
||||
|
@ -177,7 +187,7 @@ class State:
|
|||
self.interrupted = True
|
||||
|
||||
def nextjob(self):
|
||||
if opts.show_progress_every_n_steps == -1:
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
self.job_no += 1
|
||||
|
@ -207,6 +217,7 @@ class State:
|
|||
self.current_latent = None
|
||||
self.current_image = None
|
||||
self.current_image_sampling_step = 0
|
||||
self.id_live_preview = 0
|
||||
self.skipped = False
|
||||
self.interrupted = False
|
||||
self.textinfo = None
|
||||
|
@ -220,12 +231,12 @@ class State:
|
|||
|
||||
devices.torch_gc()
|
||||
|
||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
||||
def set_current_image(self):
|
||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
||||
if not parallel_processing_allowed:
|
||||
return
|
||||
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
def do_set_current_image(self):
|
||||
|
@ -234,16 +245,19 @@ class State:
|
|||
|
||||
import modules.sd_samplers
|
||||
if opts.show_progress_grid:
|
||||
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
|
||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||
else:
|
||||
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
|
||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
def assign_current_image(self, image):
|
||||
self.current_image = image
|
||||
self.id_live_preview += 1
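id_live_preview acts as a version counter: assign_current_image bumps it whenever a fresh preview exists, so a poller can skip unchanged images. A hedged, consumer-side sketch (the real polling code lives in the UI/API layer, not here):

last_seen = -1

def maybe_get_new_preview():
    global last_seen
    if shared.state.id_live_preview != last_seen:   # bumped by assign_current_image()
        last_seen = shared.state.id_live_preview
        return shared.state.current_image
    return None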
|
||||
|
||||
|
||||
state = State()
|
||||
|
||||
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
|
||||
state.server_start = time.time()
|
||||
|
||||
styles_filename = cmd_opts.styles_file
|
||||
prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
||||
|
@ -252,12 +266,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|||
|
||||
face_restorers = []
|
||||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
|
||||
self.default = default
|
||||
|
@ -348,7 +356,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
|||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
}))
|
||||
|
||||
|
@ -359,9 +367,11 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('system', "System"), {
|
||||
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
|
@ -374,42 +384,44 @@ options_templates.update(options_section(('training', "Training"), {
|
|||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
||||
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
||||
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
|
||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
||||
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
|
||||
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
|
||||
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
|
||||
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
|
||||
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
|
||||
|
@ -417,15 +429,17 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||
|
@ -434,10 +448,23 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
|
||||
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
|
||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
||||
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "Live previews"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
|
@ -452,8 +479,15 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section((None, "Hidden options"), {
|
||||
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||
}))
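Every entry in these sections is an OptionInfo(default, label, component, component_args, ...) registered through options_section. A sketch of adding one more option with the same machinery (the section key and option name are invented for illustration):

options_templates.update(options_section(('my-section', "My section"), {
    "my_option": OptionInfo(4, "An integer slider option", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
}))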
|
||||
|
||||
options_templates.update()
|
||||
|
@ -636,3 +670,17 @@ mem_mon.start()
|
|||
def listfiles(dirname):
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
|
||||
return [file for file in filenames if os.path.isfile(file)]
|
||||
|
||||
|
||||
def html_path(filename):
|
||||
return os.path.join(script_path, "html", filename)
|
||||
|
||||
|
||||
def html(filename):
|
||||
path = html_path(filename)
|
||||
|
||||
if os.path.exists(path):
|
||||
with open(path, encoding="utf8") as file:
|
||||
return file.read()
|
||||
|
||||
return ""
|
||||
|
|
modules/shared_items.py (new file, 23 lines)

@ -0,0 +1,23 @@

def realesrgan_models_names():
import modules.realesrgan_model
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]


def postprocessing_scripts():
import modules.scripts

return modules.scripts.scripts_postproc.scripts


def sd_vae_items():
import modules.sd_vae

return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)


def refresh_vae_list():
import modules.sd_vae

return modules.sd_vae.refresh_vae_list
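Each helper in the new modules/shared_items.py imports its heavy module inside the function body, so importing shared stays cheap and circular imports are easier to avoid. The same lazy-import pattern, sketched with a hypothetical module:

def heavy_model_names():
    import my_heavy_module            # hypothetical; imported only when the list is actually needed
    return [m.name for m in my_heavy_module.list_models()]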
@ -40,12 +40,18 @@ def apply_styles_to_prompt(prompt, styles):
class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "")
self.styles = {"None": self.no_style}
self.styles = {}
self.path = path

if not os.path.exists(path):
self.reload()

def reload(self):
self.styles.clear()

if not os.path.exists(self.path):
return

with open(path, "r", encoding="utf-8-sig", newline='') as file:
with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file)
for row in reader:
# Support loading old CSV format with "name, text"-columns
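reload() now opens self.path instead of the constructor-local path, so styles can be re-read from disk on an existing StyleDatabase. A hedged usage sketch:

styles = StyleDatabase(shared.cmd_opts.styles_file)
# ... styles.csv is edited on disk ...
styles.reload()              # clears self.styles and re-reads the CSV via self.path
print(list(styles.styles))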
@ -15,7 +15,8 @@ import torch
|
|||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import math
|
||||
from typing import Optional, NamedTuple, Protocol, List
|
||||
from typing import Optional, NamedTuple, List
|
||||
|
||||
|
||||
def narrow_trunc(
|
||||
input: Tensor,
|
||||
|
@ -25,12 +26,14 @@ def narrow_trunc(
|
|||
) -> Tensor:
|
||||
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
|
||||
|
||||
|
||||
class AttnChunk(NamedTuple):
|
||||
exp_values: Tensor
|
||||
exp_weights_sum: Tensor
|
||||
max_score: Tensor
|
||||
|
||||
class SummarizeChunk(Protocol):
|
||||
|
||||
class SummarizeChunk:
|
||||
@staticmethod
|
||||
def __call__(
|
||||
query: Tensor,
|
||||
|
@ -38,7 +41,8 @@ class SummarizeChunk(Protocol):
|
|||
value: Tensor,
|
||||
) -> AttnChunk: ...
|
||||
|
||||
class ComputeQueryChunkAttn(Protocol):
|
||||
|
||||
class ComputeQueryChunkAttn:
|
||||
@staticmethod
|
||||
def __call__(
|
||||
query: Tensor,
|
||||
|
@ -46,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol):
|
|||
value: Tensor,
|
||||
) -> Tensor: ...
|
||||
|
||||
|
||||
def _summarize_chunk(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
|
@ -62,10 +67,11 @@ def _summarize_chunk(
|
|||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
exp_weights = torch.exp(attn_weights - max_score)
|
||||
exp_values = torch.bmm(exp_weights, value)
|
||||
exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
|
||||
max_score = max_score.squeeze(-1)
|
||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
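The patched line keeps value in its original dtype and casts it to the attention-weights dtype only for the matrix product (the extra casts are skipped on MPS). This is what allows float32 attention weights to multiply a float16 value tensor when upcast attention is enabled. A small standalone illustration (the shapes are made up):

import torch

exp_weights = torch.rand(2, 4, 8, dtype=torch.float32)   # attention weights, possibly upcast
value = torch.rand(2, 8, 16, dtype=torch.float16)        # value tensor kept in half precision

# torch.bmm needs matching dtypes, so cast value up for the product and the result back down.
exp_values = torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)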
|
||||
|
||||
|
||||
def _query_chunk_attention(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
|
@ -106,6 +112,7 @@ def _query_chunk_attention(
|
|||
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
|
||||
return all_values / all_weights
|
||||
|
||||
|
||||
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
||||
def _get_attention_scores_no_kv_chunking(
|
||||
query: Tensor,
|
||||
|
@ -122,13 +129,15 @@ def _get_attention_scores_no_kv_chunking(
|
|||
)
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
hidden_states_slice = torch.bmm(attn_probs, value)
|
||||
hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
|
||||
return hidden_states_slice
|
||||
|
||||
|
||||
class ScannedChunk(NamedTuple):
|
||||
chunk_idx: int
|
||||
attn_chunk: AttnChunk
|
||||
|
||||
|
||||
def efficient_dot_product_attention(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
|
|
|
@ -3,8 +3,10 @@ import numpy as np
|
|||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.data import Dataset, DataLoader, Sampler
|
||||
from torchvision import transforms
|
||||
from collections import defaultdict
|
||||
from random import shuffle, choices
|
||||
|
||||
import random
|
||||
import tqdm
|
||||
|
@ -28,13 +30,11 @@ class DatasetEntry:
|
|||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
|
||||
self.dataset = []
|
||||
|
@ -50,16 +50,18 @@ class PersonalizedBase(Dataset):
|
|||
|
||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
|
||||
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.tag_drop_out = tag_drop_out
|
||||
groups = defaultdict(list)
|
||||
|
||||
print("Preparing dataset...")
|
||||
for path in tqdm.tqdm(self.image_paths):
|
||||
if shared.state.interrupted:
|
||||
raise Exception("interrupted")
|
||||
try:
|
||||
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||
image = Image.open(path).convert('RGB')
|
||||
if not varsize:
|
||||
image = image.resize((width, height), PIL.Image.BICUBIC)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
@ -103,18 +105,25 @@ class PersonalizedBase(Dataset):
|
|||
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
with devices.autocast():
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
|
||||
groups[image.size].append(len(self.dataset))
|
||||
self.dataset.append(entry)
|
||||
del torchdata
|
||||
del latent_dist
|
||||
del latent_sample
|
||||
|
||||
self.length = len(self.dataset)
|
||||
self.groups = list(groups.values())
|
||||
assert self.length > 0, "No images have been found in the dataset."
|
||||
self.batch_size = min(batch_size, self.length)
|
||||
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||
self.latent_sampling_method = latent_sampling_method
|
||||
|
||||
if len(groups) > 1:
|
||||
print("Buckets:")
|
||||
for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
|
||||
print(f" {w}x{h}: {len(ids)}")
|
||||
print()
|
||||
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
tags = filename_text.split(',')
|
||||
|
@ -137,9 +146,44 @@ class PersonalizedBase(Dataset):
|
|||
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||
return entry
|
||||
|
||||
|
||||
class GroupedBatchSampler(Sampler):
def __init__(self, data_source: PersonalizedBase, batch_size: int):
super().__init__(data_source)

n = len(data_source)
self.groups = data_source.groups
self.len = n_batch = n // batch_size
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
self.base = [int(e) // batch_size for e in expected]
self.n_rand_batches = nrb = n_batch - sum(self.base)
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
self.batch_size = batch_size

def __len__(self):
return self.len

def __iter__(self):
b = self.batch_size

for g in self.groups:
shuffle(g)

batches = []
for g in self.groups:
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
for _ in range(self.n_rand_batches):
rand_group = choices(self.groups, self.probs)[0]
batches.append(choices(rand_group, k=b))

shuffle(batches)

yield from batches


class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
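With varsize training the dataset records which entries share a resolution (data_source.groups), and GroupedBatchSampler guarantees that every batch is drawn from a single bucket, which is why the loader switches from shuffle/drop_last to a batch_sampler. A self-contained sketch with a toy dataset standing in for PersonalizedBase:

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self):
        self.items = [(512, 512), (512, 512), (512, 768), (512, 768)]   # pretend image sizes
        self.groups = [[0, 1], [2, 3]]                                  # indices grouped by size
    def __len__(self):
        return len(self.items)
    def __getitem__(self, i):
        return torch.zeros(3, *self.items[i])

ds = ToyDataset()
dl = DataLoader(ds, batch_sampler=GroupedBatchSampler(ds, batch_size=2), pin_memory=False)
for batch in dl:
    print(batch.shape)   # each batch stacks images of a single size only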
|
@ -76,10 +76,10 @@ def insert_image_data_embed(image, data):
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
next_size = next_size + ((h*d)-(next_size % (h*d)))

data_np_low.resize(next_size)
data_np_low = np.resize(data_np_low, next_size)
data_np_low = data_np_low.reshape((h, -1, d))

data_np_high.resize(next_size)
data_np_high = np.resize(data_np_high, next_size)
data_np_high = data_np_high.reshape((h, -1, d))

edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
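ndarray.resize mutates in place, zero-fills, and refuses to run on arrays that reference or are referenced by other arrays, while np.resize always returns a new array and repeats the data to fill the new size; the switch avoids that in-place failure. A quick demonstration:

import numpy as np

a = np.arange(6)
view = a[:3]             # once a view exists, a.resize(8) would raise without refcheck=False
b = np.resize(a, 8)      # np.resize never mutates; it returns a new array, repeating the data
print(b)                 # [0 1 2 3 4 5 0 1]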
@ -2,7 +2,7 @@ import datetime
|
|||
import json
|
||||
import os
|
||||
|
||||
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
|
||||
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
|
||||
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
|
||||
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
|
||||
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
||||
|
|
|
@ -6,13 +6,12 @@ import sys
|
|||
import tqdm
|
||||
import time
|
||||
|
||||
from modules import shared, images, deepbooru
|
||||
from modules.paths import models_path
|
||||
from modules import paths, shared, images, deepbooru
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.textual_inversion import autocrop
|
||||
|
||||
|
||||
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
|
||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||
try:
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
|
@ -20,7 +19,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
|
|||
if process_caption_deepbooru:
|
||||
deepbooru.model.start()
|
||||
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
||||
|
||||
finally:
|
||||
|
||||
|
@ -109,8 +108,30 @@ def split_pic(image, inverse_xy, width, height, overlap_ratio):
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted

# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
def center_crop(image: Image, w: int, h: int):
iw, ih = image.size
if ih / h < iw / w:
sw = w * ih / h
box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
else:
sh = h * iw / w
box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
return image.resize((w, h), Image.Resampling.LANCZOS, box)

def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):

def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
iw, ih = image.size
err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
if minarea <= w * h <= maxarea and err(w, h) <= threshold),
key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
default=None
)
return wh and center_crop(image, *wh)


def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
height = process_height
src = os.path.abspath(process_src)
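multicrop_pic enumerates every 64-pixel-aligned (w, h) between mindim and maxdim, keeps candidates whose area lies within [minarea, maxarea] and whose aspect-ratio error is at most threshold, then picks either the largest area or the lowest error. A worked sketch of the same search with made-up bounds:

iw, ih = 1536, 1024                    # source image with a 3:2 aspect ratio
err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h))

candidates = [(w, h) for w in range(512, 769, 64) for h in range(512, 769, 64)
              if 262144 <= w * h <= 589824 and err(w, h) <= 0.1]
best = max(candidates, key=lambda wh: (wh[0] * wh[1], -err(*wh)))   # 'Maximize area' objective
print(best)                            # (768, 512): the 3:2 candidate with the largest area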
|
@ -135,7 +156,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
params.process_caption_deepbooru = process_caption_deepbooru
|
||||
params.preprocess_txt_action = preprocess_txt_action
|
||||
|
||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||
pbar = tqdm.tqdm(files)
|
||||
for index, imagefile in enumerate(pbar):
|
||||
params.subindex = 0
|
||||
filename = os.path.join(src, imagefile)
|
||||
try:
|
||||
|
@ -143,6 +165,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
except Exception:
|
||||
continue
|
||||
|
||||
description = f"Preprocessing [Image {index}/{len(files)}]"
|
||||
pbar.set_description(description)
|
||||
shared.state.textinfo = description
|
||||
|
||||
params.src = filename
|
||||
|
||||
existing_caption = None
|
||||
|
@ -172,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
|
||||
dnn_model_path = None
|
||||
try:
|
||||
dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
|
||||
dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
|
||||
except Exception as e:
|
||||
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
|
||||
|
||||
|
@ -189,6 +215,14 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||
save_pic(focal, index, params, existing_caption=existing_caption)
|
||||
process_default_resize = False
|
||||
|
||||
if process_multicrop:
|
||||
cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
||||
if cropped is not None:
|
||||
save_pic(cropped, index, params, existing_caption=existing_caption)
|
||||
else:
|
||||
print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
|
||||
process_default_resize = False
|
||||
|
||||
if process_default_resize:
|
||||
img = images.resize_image(1, img, width, height)
|
||||
save_pic(img, index, params, existing_caption=existing_caption)
|
||||
|
|
|
@ -2,25 +2,43 @@ import os
|
|||
import sys
|
||||
import traceback
|
||||
import inspect
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
import html
|
||||
import datetime
|
||||
import csv
|
||||
import safetensors.torch
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
|
||||
insert_image_data_embed, extract_image_data_embed,
|
||||
caption_image_overlay)
|
||||
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
|
||||
from modules.textual_inversion.logging import save_settings_to_file
|
||||
|
||||
|
||||
TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
|
||||
textual_inversion_templates = {}
|
||||
|
||||
|
||||
def list_textual_inversion_templates():
|
||||
textual_inversion_templates.clear()
|
||||
|
||||
for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
|
||||
for fn in fns:
|
||||
path = os.path.join(root, fn)
|
||||
|
||||
textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
|
||||
|
||||
return textual_inversion_templates
|
||||
|
||||
|
||||
class Embedding:
|
||||
def __init__(self, vec, name, step=None):
|
||||
self.vec = vec
|
||||
|
@ -32,6 +50,7 @@ class Embedding:
|
|||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.optimizer_state_dict = None
|
||||
self.filename = None
|
||||
|
||||
def save(self, filename):
|
||||
embedding_data = {
|
||||
|
@ -93,6 +112,7 @@ class EmbeddingDatabase:
|
|||
self.skipped_embeddings = {}
|
||||
self.expected_shape = -1
|
||||
self.embedding_dirs = {}
|
||||
self.previously_displayed_embeddings = ()
|
||||
|
||||
def add_embedding_dir(self, path):
|
||||
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
||||
|
@ -135,6 +155,8 @@ class EmbeddingDatabase:
|
|||
name = data.get('name', name)
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
elif ext in ['.SAFETENSORS']:
|
||||
data = safetensors.torch.load_file(path, device="cpu")
|
||||
else:
|
||||
return
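Embeddings may now also be distributed as .safetensors; they are loaded onto the CPU with safetensors.torch.load_file, mirroring the torch.load branch above. A standalone sketch (the file name is illustrative):

import safetensors.torch

data = safetensors.torch.load_file("my-embedding.safetensors", device="cpu")  # dict of tensors
# .pt/.bin embeddings keep using: torch.load(path, map_location="cpu")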
|
||||
|
||||
|
@ -162,6 +184,7 @@ class EmbeddingDatabase:
|
|||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
embedding.filename = path
|
||||
|
||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
|
@ -172,7 +195,7 @@ class EmbeddingDatabase:
|
|||
if not os.path.isdir(embdir.path):
|
||||
return
|
||||
|
||||
for root, dirs, fns in os.walk(embdir.path):
|
||||
for root, dirs, fns in os.walk(embdir.path, followlinks=True):
|
||||
for fn in fns:
|
||||
try:
|
||||
fullfn = os.path.join(root, fn)
|
||||
|
@ -206,9 +229,12 @@ class EmbeddingDatabase:
|
|||
self.load_from_dir(embdir)
|
||||
embdir.update()
|
||||
|
||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||
if len(self.skipped_embeddings) > 0:
|
||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
|
||||
if self.previously_displayed_embeddings != displayed_embeddings:
|
||||
self.previously_displayed_embeddings = displayed_embeddings
|
||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||
if len(self.skipped_embeddings) > 0:
|
||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||
|
||||
def find_embedding_at_position(self, tokens, offset):
|
||||
token = tokens[offset]
|
||||
|
@ -230,11 +256,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
|||
with devices.autocast():
|
||||
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
||||
|
||||
embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
|
||||
#cond_model expects at least some text, so we provide '*' as backup.
|
||||
embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
|
||||
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||
|
||||
for i in range(num_vectors_per_token):
|
||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||
#Only copy if we provided an init_text, otherwise keep vectors as zeros
|
||||
if init_text:
|
||||
for i in range(num_vectors_per_token):
|
||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
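With an empty init text the vectors now stay zero instead of being seeded from the '*' fallback (the fallback exists only because the cond model needs some text to encode). A behavioural sketch with toy shapes:

import torch

num_vectors_per_token = 2
embedded = torch.randn(2, 768)        # stand-in for cond_model.encode_embedding_init_text(...)
init_text = ""                        # user left the initialization text empty

vec = torch.zeros((num_vectors_per_token, embedded.shape[1]))
if init_text:                         # only copy when text was actually provided
    for i in range(num_vectors_per_token):
        vec[i] = embedded[i * embedded.shape[0] // num_vectors_per_token]
print(vec.abs().sum())                # tensor(0.) when init_text is empty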
|
||||
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
|
@ -273,8 +302,32 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
|||
**values,
|
||||
})
|
||||
|
||||
def tensorboard_setup(log_directory):
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
return SummaryWriter(
log_dir=os.path.join(log_directory, "tensorboard"),
flush_secs=shared.opts.training_tensorboard_flush_every)

def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)

def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
tensorboard_writer.add_scalar(tag=tag,
scalar_value=value, global_step=step)

def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
# Convert a pil image to a torch tensor
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
len(pil_image.getbands()))
img_tensor = img_tensor.permute((2, 0, 1))

tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||
assert model_name, f"{name} not selected"
|
||||
assert learn_rate, "Learning rate is empty or 0"
|
||||
assert isinstance(batch_size, int), "Batch size must be integer"
|
||||
|
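(Illustration, not part of this diff: a minimal sketch of how the TensorBoard helpers above fit together; the loop-state variables used here are assumed stand-ins, not names from this commit.)

    writer = tensorboard_setup(log_directory)   # creates <log_directory>/tensorboard and returns a SummaryWriter
    tensorboard_add(writer, loss=loss_step, global_step=global_step, step=epoch_step, learn_rate=lr, epoch_num=epoch_num)
    tensorboard_add_image(writer, f"Validation at epoch {epoch_num}", preview_image, global_step)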
@ -284,8 +337,9 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
    assert data_root, "Dataset directory is empty"
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"
    assert template_file, "Prompt template file is empty"
    assert os.path.isfile(template_file), "Prompt template file doesn't exist"
    assert template_filename, "Prompt template file not selected"
    assert template_file, f"Prompt template file {template_filename} not found"
    assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
    assert steps, "Max steps is empty or 0"
    assert isinstance(steps, int), "Max steps must be integer"
    assert steps > 0, "Max steps must be positive"
|
@ -297,10 +351,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
|||
assert log_directory, "Log directory is empty"
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||
template_file = textual_inversion_templates.get(template_filename, None)
|
||||
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||
template_file = template_file.path
|
||||
|
||||
shared.state.job = "train-embedding"
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
|
@ -348,13 +404,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed

    if shared.opts.training_enable_tensorboard:
        tensorboard_writer = tensorboard_setup(log_directory)

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)

    if shared.opts.save_training_settings_to_txt:
        save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
        save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})

    latent_sampling_method = ds.latent_sampling_method

@ -399,6 +458,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        sd_hijack_checkpoint.add()

        for i in range((steps-initial_step) * gradient_step):
            if scheduler.finished:
                break
@ -455,7 +516,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                epoch_num = embedding.step // steps_per_epoch
                epoch_step = embedding.step % steps_per_epoch

                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
                description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
                pbar.set_description(description)
                if embedding_dir is not None and steps_done % save_embedding_every == 0:
                    # Before saving, change name to match current checkpoint.
                    embedding_name_every = f'{embedding_name}-{steps_done}'
@ -505,10 +567,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                        shared.sd_model.first_stage_model.to(devices.cpu)

                    if image is not None:
                        shared.state.current_image = image
                        shared.state.assign_current_image(image)

                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                    if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
                        tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)

                    if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:

                        last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
@ -526,7 +592,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_

                        checkpoint = sd_models.select_checkpoint()
                        footer_left = checkpoint.model_name
                        footer_mid = '[{}]'.format(checkpoint.hash)
                        footer_mid = '[{}]'.format(checkpoint.shorthash)
                        footer_right = '{}v {}s'.format(vectorSize, steps_done)

                        captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
@ -559,16 +625,18 @@ Last saved image: {html.escape(last_saved_image)}<br/>
        pbar.close()
        shared.sd_model.first_stage_model.to(devices.device)
        shared.parallel_processing_allowed = old_parallel_processing_allowed
        sd_hijack_checkpoint.remove()

    return embedding, filename


def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    old_embedding_name = embedding.name
    old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
    old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
    old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
    try:
        embedding.sd_checkpoint = checkpoint.hash
        embedding.sd_checkpoint = checkpoint.shorthash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
35
modules/timer.py
Normal file
@ -0,0 +1,35 @@
import time


class Timer:
    def __init__(self):
        self.start = time.time()
        self.records = {}
        self.total = 0

    def elapsed(self):
        end = time.time()
        res = end - self.start
        self.start = end
        return res

    def record(self, category, extra_time=0):
        e = self.elapsed()
        if category not in self.records:
            self.records[category] = 0

        self.records[category] += e + extra_time
        self.total += e + extra_time

    def summary(self):
        res = f"{self.total:.1f}s"

        additions = [x for x in self.records.items() if x[1] >= 0.1]
        if not additions:
            return res

        res += " ("
        res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
        res += ")"

        return res
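(Illustration, not part of this diff: how the new Timer is intended to be used; the category names here are examples only.)

    timer = Timer()
    do_load_model()                      # hypothetical work
    timer.record("load model")
    do_apply_optimizations()             # hypothetical work
    timer.record("apply optimizations")
    print(timer.summary())               # e.g. "3.2s (load model: 2.8s)"; categories under 0.1s are omitted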
@ -8,13 +8,13 @@ import modules.processing as processing
from modules.ui import plaintext_to_html


def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
    p = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=prompt,
        styles=[prompt_style, prompt_style2],
        styles=prompt_styles,
        negative_prompt=negative_prompt,
        seed=seed,
        subseed=subseed,
940
modules/ui.py
File diff suppressed because it is too large
202
modules/ui_common.py
Normal file
@ -0,0 +1,202 @@
import json
import html
import os
import platform
import sys

import gradio as gr
import subprocess as sp

from modules import call_queue, shared
from modules.generation_parameters_copypaste import image_from_url_text
import modules.images

folder_symbol = '\U0001f4c2'  # 📂


def update_generation_info(generation_info, html_info, img_index):
    try:
        generation_info = json.loads(generation_info)
        if img_index < 0 or img_index >= len(generation_info["infotexts"]):
            return html_info, gr.update()
        return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
    except Exception:
        pass
    # if the json parse or anything else fails, just return the old html_info
    return html_info, gr.update()


def plaintext_to_html(text):
    text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
    return text


def save_files(js_data, images, do_make_zip, index):
    import csv
    filenames = []
    fullfns = []

    #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
    class MyObject:
        def __init__(self, d=None):
            if d is not None:
                for key, value in d.items():
                    setattr(self, key, value)

    data = json.loads(js_data)

    p = MyObject(data)
    path = shared.opts.outdir_save
    save_to_dirs = shared.opts.use_save_to_dirs_for_ui
    extension: str = shared.opts.samples_format
    start_index = 0

    if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]):  # ensures we are looking at a specific non-grid picture, and we have save_selected_only

        images = [images[index]]
        start_index = index

    os.makedirs(shared.opts.outdir_save, exist_ok=True)

    with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
        at_start = file.tell() == 0
        writer = csv.writer(file)
        if at_start:
            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])

        for image_index, filedata in enumerate(images, start_index):
            image = image_from_url_text(filedata)

            is_grid = image_index < p.index_of_first_image
            i = 0 if is_grid else (image_index - p.index_of_first_image)

            fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)

            filename = os.path.relpath(fullfn, path)
            filenames.append(filename)
            fullfns.append(fullfn)
            if txt_fullfn:
                filenames.append(os.path.basename(txt_fullfn))
                fullfns.append(txt_fullfn)

        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])

    # Make Zip
    if do_make_zip:
        zip_filepath = os.path.join(path, "images.zip")

        from zipfile import ZipFile
        with ZipFile(zip_filepath, "w") as zip_file:
            for i in range(len(fullfns)):
                with open(fullfns[i], mode="rb") as f:
                    zip_file.writestr(filenames[i], f.read())
        fullfns.insert(0, zip_filepath)

    return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")


def create_output_panel(tabname, outdir):
    from modules import shared
    import modules.generation_parameters_copypaste as parameters_copypaste

    def open_folder(f):
        if not os.path.exists(f):
            print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
            return
        elif not os.path.isdir(f):
            print(f"""
WARNING
An open_folder request was made with an argument that is not a folder.
This could be an error or a malicious attempt to run code on your computer.
Requested path was: {f}
""", file=sys.stderr)
            return

        if not shared.cmd_opts.hide_ui_dir_config:
            path = os.path.normpath(f)
            if platform.system() == "Windows":
                os.startfile(path)
            elif platform.system() == "Darwin":
                sp.Popen(["open", path])
            elif "microsoft-standard-WSL2" in platform.uname().release:
                sp.Popen(["wsl-open", path])
            else:
                sp.Popen(["xdg-open", path])

    with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
        with gr.Group(elem_id=f"{tabname}_gallery_container"):
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)

        generation_info = None
        with gr.Column():
            with gr.Row(elem_id=f"image_buttons_{tabname}"):
                open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')

                if tabname != "extras":
                    save = gr.Button('Save', elem_id=f'save_{tabname}')
                    save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')

                buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])

            open_folder_button.click(
                fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
                inputs=[],
                outputs=[],
            )

            if tabname != "extras":
                with gr.Row():
                    download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')

                with gr.Group():
                    html_info = gr.HTML(elem_id=f'html_info_{tabname}')
                    html_log = gr.HTML(elem_id=f'html_log_{tabname}')

                    generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
                    if tabname == 'txt2img' or tabname == 'img2img':
                        generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
                        generation_info_button.click(
                            fn=update_generation_info,
                            _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
                            inputs=[generation_info, html_info, html_info],
                            outputs=[html_info, html_info],
                        )

                    save.click(
                        fn=call_queue.wrap_gradio_call(save_files),
                        _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
                        inputs=[
                            generation_info,
                            result_gallery,
                            html_info,
                            html_info,
                        ],
                        outputs=[
                            download_files,
                            html_log,
                        ],
                        show_progress=False,
                    )

                    save_zip.click(
                        fn=call_queue.wrap_gradio_call(save_files),
                        _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
                        inputs=[
                            generation_info,
                            result_gallery,
                            html_info,
                            html_info,
                        ],
                        outputs=[
                            download_files,
                            html_log,
                        ]
                    )

            else:
                html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
                html_info = gr.HTML(elem_id=f'html_info_{tabname}')
                html_log = gr.HTML(elem_id=f'html_log_{tabname}')

            parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
            return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
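(Illustration, not part of this diff: a rough sketch of wiring create_output_panel into a tab; the surrounding Blocks layout and the outdir value are assumptions.)

    from modules import shared, ui_common

    with gr.Blocks() as demo:
        generate = gr.Button("Generate")   # hypothetical trigger button
        result_gallery, generation_info, html_info, html_log = ui_common.create_output_panel("txt2img", shared.opts.outdir_txt2img_samples)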
@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent):
        return "button"


class ToolButtonTop(gr.Button, gr.components.FormComponent):
    """Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""

    def __init__(self, **kwargs):
        super().__init__(variant="tool-top", **kwargs)

    def get_block_name(self):
        return "button"


class FormRow(gr.Row, gr.components.FormComponent):
    """Same as gr.Row but fits inside gradio forms"""


@ -31,3 +41,18 @@ class FormHTML(gr.HTML, gr.components.FormComponent):
    def get_block_name(self):
        return "html"


class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
    """Same as gr.ColorPicker but fits inside gradio forms"""

    def get_block_name(self):
        return "colorpicker"


class DropdownMulti(gr.Dropdown):
    """Same as gr.Dropdown but always multiselect"""
    def __init__(self, **kwargs):
        super().__init__(multiselect=True, **kwargs)

    def get_block_name(self):
        return "dropdown"
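(Illustration, not part of this diff: a sketch of the form components above used inside an existing gr.Blocks form; the labels and choices are placeholders.)

    from modules import ui_components

    with ui_components.FormRow():
        color = ui_components.FormColorPicker(label="Background color")
        tags = ui_components.DropdownMulti(label="Tags", choices=["a", "b", "c"])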
@ -13,7 +13,7 @@ import shutil
import errno

from modules import extensions, shared, paths
from modules.call_queue import wrap_gradio_gpu_call

available_extensions = {"extensions": []}
@ -50,12 +50,17 @@ def apply_and_restart(disable_list, update_list):
    shared.state.need_restart = True


def check_updates():
def check_updates(id_task, disable_list):
    check_access()

    for ext in extensions.extensions:
        if ext.remote is None:
            continue
    disabled = json.loads(disable_list)
    assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"

    exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
    shared.state.job_count = len(exts)

    for ext in exts:
        shared.state.textinfo = ext.name

        try:
            ext.check_updates()
@ -63,7 +68,9 @@ def check_updates():
            print(f"Error checking updates for {ext.name}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return extension_table()
        shared.state.nextjob()

    return extension_table(), ""


def extension_table():
@ -132,7 +139,7 @@ def install_extension_from_url(dirname, url):
    normalized_url = normalize_git_url(url)
    assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'

    tmpdir = os.path.join(paths.script_path, "tmp", dirname)
    tmpdir = os.path.join(paths.data_path, "tmp", dirname)

    try:
        shutil.rmtree(tmpdir, True)
@ -273,12 +280,13 @@ def create_ui():
    with gr.Tabs(elem_id="tabs_extensions") as tabs:
        with gr.TabItem("Installed"):

            with gr.Row():
            with gr.Row(elem_id="extensions_installed_top"):
                apply = gr.Button(value="Apply and restart UI", variant="primary")
                check = gr.Button(value="Check for updates")
                extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
                extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)

            info = gr.HTML()
            extensions_table = gr.HTML(lambda: extension_table())

            apply.click(
@ -289,10 +297,10 @@ def create_ui():
            )

            check.click(
                fn=check_updates,
                fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
                _js="extensions_check",
                inputs=[],
                outputs=[extensions_table],
                inputs=[info, extensions_disabled_list],
                outputs=[extensions_table, info],
            )

            with gr.TabItem("Available"):
244
modules/ui_extra_networks.py
Normal file
@ -0,0 +1,244 @@
import glob
import os.path
import urllib.parse
from pathlib import Path

from modules import shared
import gradio as gr
import json
import html

from modules.generation_parameters_copypaste import image_from_url_text

extra_pages = []
allowed_dirs = set()


def register_page(page):
    """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""

    extra_pages.append(page)
    allowed_dirs.clear()
    allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))


def add_pages_to_demo(app):
    def fetch_file(filename: str = ""):
        from starlette.responses import FileResponse

        if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
            raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")

        if os.path.splitext(filename)[1].lower() != ".png":
            raise ValueError(f"File cannot be fetched: {filename}. Only png.")

        # would profit from returning 304
        return FileResponse(filename, headers={"Accept-Ranges": "bytes"})

    app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])

class ExtraNetworksPage:
    def __init__(self, title):
        self.title = title
        self.name = title.lower()
        self.card_page = shared.html("extra-networks-card.html")
        self.allow_negative_prompt = False

    def refresh(self):
        pass

    def link_preview(self, filename):
        return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))

    def search_terms_from_path(self, filename, possible_directories=None):
        abspath = os.path.abspath(filename)

        for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
            parentdir = os.path.abspath(parentdir)
            if abspath.startswith(parentdir):
                return abspath[len(parentdir):].replace('\\', '/')

        return ""

    def create_html(self, tabname):
        view = shared.opts.extra_networks_default_view
        items_html = ''

        subdirs = {}
        for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
            for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
                if not os.path.isdir(x):
                    continue

                subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
                while subdir.startswith("/"):
                    subdir = subdir[1:]

                subdirs[subdir] = 1

        if subdirs:
            subdirs = {"": 1, **subdirs}

        subdirs_html = "".join([f"""
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
{html.escape(subdir if subdir!="" else "all")}
</button>
""" for subdir in subdirs])

        for item in self.list_items():
            items_html += self.create_html_for_item(item, tabname)

        if items_html == '':
            dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
            items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)

        res = f"""
<div id='{tabname}_{self.name}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
{subdirs_html}
</div>
<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
{items_html}
</div>
"""

        return res

    def list_items(self):
        raise NotImplementedError()

    def allowed_directories_for_previews(self):
        return []

    def create_html_for_item(self, item, tabname):
        preview = item.get("preview", None)

        onclick = item.get("onclick", None)
        if onclick is None:
            onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'

        args = {
            "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
            "prompt": item.get("prompt", None),
            "tabname": json.dumps(tabname),
            "local_preview": json.dumps(item["local_preview"]),
            "name": item["name"],
            "card_clicked": onclick,
            "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
            "search_term": item.get("search_term", ""),
        }

        return self.card_page.format(**args)

def intialize():
    extra_pages.clear()


class ExtraNetworksUi:
    def __init__(self):
        self.pages = None
        self.stored_extra_pages = None

        self.button_save_preview = None
        self.preview_target_filename = None

        self.tabname = None


def pages_in_preferred_order(pages):
    tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]

    def tab_name_score(name):
        name = name.lower()
        for i, possible_match in enumerate(tab_order):
            if possible_match in name:
                return i

        return len(pages)

    tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}

    return sorted(pages, key=lambda x: tab_scores[x.name])


def create_ui(container, button, tabname):
    ui = ExtraNetworksUi()
    ui.pages = []
    ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
    ui.tabname = tabname

    with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
        for page in ui.stored_extra_pages:
            with gr.Tab(page.title):
                page_elem = gr.HTML(page.create_html(ui.tabname))
                ui.pages.append(page_elem)

    filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
    button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
    button_close = gr.Button('Close', elem_id=tabname+"_extra_close")

    ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
    ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)

    def toggle_visibility(is_visible):
        is_visible = not is_visible
        return is_visible, gr.update(visible=is_visible)

    state_visible = gr.State(value=False)
    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
    button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])

    def refresh():
        res = []

        for pg in ui.stored_extra_pages:
            pg.refresh()
            res.append(pg.create_html(ui.tabname))

        return res

    button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)

    return ui


def path_is_parent(parent_path, child_path):
    parent_path = os.path.abspath(parent_path)
    child_path = os.path.abspath(child_path)

    return child_path.startswith(parent_path)


def setup_ui(ui, gallery):
    def save_preview(index, images, filename):
        if len(images) == 0:
            print("There is no image in gallery to save as a preview.")
            return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]

        index = int(index)
        index = 0 if index < 0 else index
        index = len(images) - 1 if index >= len(images) else index

        img_info = images[index if index >= 0 else 0]
        image = image_from_url_text(img_info)

        is_allowed = False
        for extra_page in ui.stored_extra_pages:
            if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
                is_allowed = True
                break

        assert is_allowed, f'writing to {filename} is not allowed'

        image.save(filename)

        return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]

    ui.button_save_preview.click(
        fn=save_preview,
        _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
        inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
        outputs=[*ui.pages]
    )
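(Illustration, not part of this diff: a sketch of an extension adding its own tab through this API; per the register_page docstring this would normally happen from an on_before_ui() callback, and every name and path below is hypothetical.)

    from modules import ui_extra_networks

    class ExtraNetworksPageExample(ui_extra_networks.ExtraNetworksPage):
        def __init__(self):
            super().__init__('Examples')

        def list_items(self):
            yield {
                "name": "example",
                "filename": "/path/to/example",
                "preview": None,
                "search_term": "example",
                "prompt": '"example prompt"',
                "local_preview": "/path/to/example.png",
            }

        def allowed_directories_for_previews(self):
            return ["/path/to/examples"]

    ui_extra_networks.register_page(ExtraNetworksPageExample())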
38
modules/ui_extra_networks_checkpoints.py
Normal file
@ -0,0 +1,38 @@
import html
import json
import os
import urllib.parse

from modules import shared, ui_extra_networks, sd_models


class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('Checkpoints')

    def refresh(self):
        shared.refresh_checkpoints()

    def list_items(self):
        for name, checkpoint in sd_models.checkpoints_list.items():
            path, ext = os.path.splitext(checkpoint.filename)
            previews = [path + ".png", path + ".preview.png"]

            preview = None
            for file in previews:
                if os.path.isfile(file):
                    preview = self.link_preview(file)
                    break

            yield {
                "name": checkpoint.name_for_extra,
                "filename": path,
                "preview": preview,
                "search_term": self.search_terms_from_path(checkpoint.filename),
                "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
                "local_preview": path + ".png",
            }

    def allowed_directories_for_previews(self):
        return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
36
modules/ui_extra_networks_hypernets.py
Normal file
@ -0,0 +1,36 @@
import json
import os

from modules import shared, ui_extra_networks


class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('Hypernetworks')

    def refresh(self):
        shared.reload_hypernetworks()

    def list_items(self):
        for name, path in shared.hypernetworks.items():
            path, ext = os.path.splitext(path)
            previews = [path + ".png", path + ".preview.png"]

            preview = None
            for file in previews:
                if os.path.isfile(file):
                    preview = self.link_preview(file)
                    break

            yield {
                "name": name,
                "filename": path,
                "preview": preview,
                "search_term": self.search_terms_from_path(path),
                "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                "local_preview": path + ".png",
            }

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.hypernetwork_dir]
Some files were not shown because too many files have changed in this diff