Shondoit 2023-02-08 12:05:27 +00:00 committed by GitHub
commit 01b8e6523a
4 changed files with 54 additions and 0 deletions

launch.py

@@ -349,6 +349,11 @@ def tests(test_dir):
 def start():
     print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
+
+    # Do this before Torch is loaded
+    if '--deterministic-training' in sys.argv:
+        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+
     import webui
     if '--nowebui' in sys.argv:
         webui.api_only()

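The environment variable has to be set before anything initializes cuBLAS: on CUDA 10.2+, PyTorch's deterministic mode makes cuBLAS-backed ops raise a RuntimeError unless CUBLAS_WORKSPACE_CONFIG is ':4096:8' or ':16:8'. A minimal standalone sketch of the ordering requirement (illustration only, not code from the commit):

```python
import os

# Must be set before the first CUDA/cuBLAS initialization; ':16:8' is the
# lower-memory alternative documented by PyTorch.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch

torch.use_deterministic_algorithms(True)
# Without the variable above, cuBLAS-backed ops (e.g. matmul on CUDA 10.2+)
# raise a RuntimeError when they are first called in deterministic mode.
```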
modules/hypernetworks/hypernetwork.py

@@ -6,9 +6,11 @@ import os
 import sys
 import traceback
 import inspect
+import random
 import modules.textual_inversion.dataset
 import torch
+import numpy as np
 import tqdm
 from einops import rearrange, repeat
 from ldm.util import default
@@ -552,6 +554,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+
+    if shared.cmd_opts.deterministic_training:
+        old_deterministic_state = set_deterministic()
     pin_memory = shared.opts.pin_memory
     ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
@@ -635,6 +640,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 if shared.state.interrupted:
                     break
+
+                if shared.cmd_opts.deterministic_training:
+                    set_deterministic_seed(hypernetwork.step * gradient_step + j)
                 if clip_grad:
                     clip_grad_sched.step(hypernetwork.step)
@@ -775,6 +783,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
+    if shared.cmd_opts.deterministic_training:
+        reset_deterministic(old_deterministic_state)
+
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
     hypernetwork.optimizer_name = optimizer_name
     if shared.opts.save_optimizer_state:

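set_deterministic() (defined in textual_inversion.py below) snapshots the global determinism flags before overriding them, and reset_deterministic() restores the snapshot once training finishes, so the flag does not leak into normal inference. A context manager is an alternative way to express the same save/seed/restore pattern; this is a hypothetical sketch, not code from the commit, and unlike the commit it restores the flags even if training raises:

```python
from contextlib import contextmanager

import torch


@contextmanager
def deterministic_training(seed=0):
    # Snapshot the same global flags that set_deterministic() saves.
    state = {
        'use_deterministic_algorithms': torch.are_deterministic_algorithms_enabled(),
        'cudnn_deterministic': torch.backends.cudnn.deterministic,
        'cudnn_benchmark': torch.backends.cudnn.benchmark,
    }
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    try:
        yield
    finally:
        # Restore even on error, mirroring reset_deterministic().
        torch.use_deterministic_algorithms(state['use_deterministic_algorithms'])
        torch.backends.cudnn.deterministic = state['cudnn_deterministic']
        torch.backends.cudnn.benchmark = state['cudnn_benchmark']
```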
modules/shared.py

@@ -61,6 +61,7 @@ parser.add_argument("--clip-models-path", type=str, help="Path to directory with
 parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
 parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
+parser.add_argument("--deterministic-training", action='store_true', help="Enable deterministic training of Hypernetworks and Textual Inversion. (Same settings and dataset will give the same result)")
 parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
 parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")

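Because argparse converts dashes to underscores, the flag surfaces as shared.cmd_opts.deterministic_training, which is what the training code tests; launch.py instead checks sys.argv directly, since it runs before this parser exists. A minimal standalone sketch of that mapping (not the project's actual parser):

```python
import argparse

parser = argparse.ArgumentParser()
# Same flag shape as the commit adds to modules/shared.py.
parser.add_argument("--deterministic-training", action='store_true',
                    help="Enable deterministic training of Hypernetworks and Textual Inversion.")

# Dashes become underscores on the parsed namespace.
cmd_opts = parser.parse_args(["--deterministic-training"])
assert cmd_opts.deterministic_training is True
```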
modules/textual_inversion/textual_inversion.py

@@ -2,9 +2,11 @@ import os
 import sys
 import traceback
 import inspect
+import random
 from collections import namedtuple
 import torch
+import numpy as np
 import tqdm
 import html
 import datetime
@@ -351,6 +353,32 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
     assert log_directory, "Log directory is empty"
+
+def set_deterministic():
+    state = {
+        'use_deterministic_algorithms': torch.are_deterministic_algorithms_enabled(),
+        'cudnn_deterministic': torch.backends.cudnn.deterministic,
+        'cudnn_benchmarks': torch.backends.cudnn.benchmark,
+    }
+    torch.use_deterministic_algorithms(True)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    set_deterministic_seed(0)
+    return state
+
+
+def reset_deterministic(state):
+    torch.use_deterministic_algorithms(state['use_deterministic_algorithms'])
+    torch.backends.cudnn.deterministic = state['cudnn_deterministic']
+    torch.backends.cudnn.benchmark = state['cudnn_benchmarks']
+
+
+def set_deterministic_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
@@ -408,6 +436,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
     if shared.opts.training_enable_tensorboard:
         tensorboard_writer = tensorboard_setup(log_directory)
+
+    if shared.cmd_opts.deterministic_training:
+        old_deterministic_state = set_deterministic()
     pin_memory = shared.opts.pin_memory
     ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
@@ -475,6 +506,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 if shared.state.interrupted:
                     break
+
+                if shared.cmd_opts.deterministic_training:
+                    set_deterministic_seed(embedding.step * gradient_step + j)
                 if clip_grad:
                     clip_grad_sched.step(embedding.step)
@@ -627,6 +661,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
         shared.parallel_processing_allowed = old_parallel_processing_allowed
         sd_hijack_checkpoint.remove()
+
+    if shared.cmd_opts.deterministic_training:
+        reset_deterministic(old_deterministic_state)
     return embedding, filename
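
Seeding every micro-batch with embedding.step * gradient_step + j ties the RNG stream to training progress rather than to ambient interpreter state, so two runs with the same settings and dataset draw identical dropout masks and noise, which is what makes the result reproducible. A toy demonstration of that property (hypothetical values, CPU-only):

```python
import random

import numpy as np
import torch


def set_deterministic_seed(seed):
    # Same helper as the commit: seed every RNG the training loop touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def draws_at(step, gradient_step, j):
    # Seed derived purely from training progress, as in the training loops.
    set_deterministic_seed(step * gradient_step + j)
    return torch.rand(3)


# The same (step, j) position yields the same draws in any run, so a
# restarted run replays the exact RNG stream of the original.
assert torch.equal(draws_at(10, 4, 2), draws_at(10, 4, 2))
```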