From 114a8a2f62c31cea890ce8a30d2a9cda5c2ab5a5 Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Wed, 18 Jan 2023 19:50:22 +0100
Subject: [PATCH] Add deterministic training

---
 launch.py                                  |  5 +++
 modules/hypernetworks/hypernetwork.py      | 11 ++++++
 modules/shared.py                          |  1 +
 .../textual_inversion/textual_inversion.py | 37 +++++++++++++++++++
 4 files changed, 54 insertions(+)

diff --git a/launch.py b/launch.py
index 9fd766d1..770c337e 100644
--- a/launch.py
+++ b/launch.py
@@ -349,6 +349,11 @@ def tests(test_dir):
 
 def start():
     print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
+
+    #Do this before Torch is loaded
+    if '--deterministic-training' in sys.argv:
+        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+
     import webui
     if '--nowebui' in sys.argv:
         webui.api_only()
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 825a93b2..cbe70d2c 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -6,9 +6,11 @@ import os
 import sys
 import traceback
 import inspect
+import random
 
 import modules.textual_inversion.dataset
 import torch
+import numpy as np
 import tqdm
 from einops import rearrange, repeat
 from ldm.util import default
@@ -552,6 +554,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
 
+    if shared.cmd_opts.deterministic_training:
+        old_deterministic_state = set_deterministic()
+
     pin_memory = shared.opts.pin_memory
 
     ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
@@ -635,6 +640,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
             if shared.state.interrupted:
                 break
 
+            if shared.cmd_opts.deterministic_training:
+                set_deterministic_seed(hypernetwork.step * gradient_step + j)
+
             if clip_grad:
                 clip_grad_sched.step(hypernetwork.step)
 
@@ -775,6 +783,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
+    if shared.cmd_opts.deterministic_training:
+        reset_deterministic(old_deterministic_state)
+
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
     hypernetwork.optimizer_name = optimizer_name
     if shared.opts.save_optimizer_state:
diff --git a/modules/shared.py b/modules/shared.py
index 79fbf724..df96fbd9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -61,6 +61,7 @@ parser.add_argument("--clip-models-path", type=str, help="Path to directory with
 parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
 parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
+parser.add_argument("--deterministic-training", action='store_true', help="Enable deterministic training of Hypernetworks and Textual Inversion. (Same settings and dataset will give the same result)")
 parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
 parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index a1a406c2..042827ec 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -2,9 +2,11 @@ import os
 import sys
 import traceback
 import inspect
+import random
 from collections import namedtuple
 
 import torch
+import numpy as np
 import tqdm
 import html
 import datetime
@@ -351,6 +353,32 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
     assert log_directory, "Log directory is empty"
 
 
+def set_deterministic():
+    state = {
+        'use_deterministic_algorithms': torch.are_deterministic_algorithms_enabled(),
+        'cudnn_deterministic': torch.backends.cudnn.deterministic,
+        'cudnn_benchmarks': torch.backends.cudnn.benchmark,
+    }
+
+    torch.use_deterministic_algorithms(True)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    set_deterministic_seed(0)
+
+    return state
+
+def reset_deterministic(state):
+    torch.use_deterministic_algorithms(state['use_deterministic_algorithms'])
+    torch.backends.cudnn.deterministic = state['cudnn_deterministic']
+    torch.backends.cudnn.benchmark = state['cudnn_benchmarks']
+
+def set_deterministic_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
@@ -408,6 +436,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
     if shared.opts.training_enable_tensorboard:
         tensorboard_writer = tensorboard_setup(log_directory)
 
+    if shared.cmd_opts.deterministic_training:
+        old_deterministic_state = set_deterministic()
+
     pin_memory = shared.opts.pin_memory
 
     ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
@@ -475,6 +506,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
             if shared.state.interrupted:
                 break
 
+            if shared.cmd_opts.deterministic_training:
+                set_deterministic_seed(embedding.step * gradient_step + j)
+
             if clip_grad:
                 clip_grad_sched.step(embedding.step)
 
@@ -627,6 +661,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
         shared.parallel_processing_allowed = old_parallel_processing_allowed
         sd_hijack_checkpoint.remove()
 
+    if shared.cmd_opts.deterministic_training:
+        reset_deterministic(old_deterministic_state)
+
     return embedding, filename
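
Reviewer note (not part of the patch): below is a minimal standalone sketch for checking the save/seed/restore pattern this patch introduces, outside the webui. It reuses the three helpers from textual_inversion.py above (with comments added); the file name, run_once() loop, and its torch.randn() draw are illustrative assumptions standing in for a real training micro-batch.

# deterministic_training_check.py -- reviewer sketch, not part of the patch.
import os
import random

# As in the launch.py change above, set this before torch initialises cuBLAS,
# otherwise torch.use_deterministic_algorithms(True) cannot hold on CUDA.
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')

import numpy as np
import torch


def set_deterministic():
    # Save the current flags so they can be restored after training.
    state = {
        'use_deterministic_algorithms': torch.are_deterministic_algorithms_enabled(),
        'cudnn_deterministic': torch.backends.cudnn.deterministic,
        'cudnn_benchmarks': torch.backends.cudnn.benchmark,
    }
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    set_deterministic_seed(0)
    return state


def reset_deterministic(state):
    # Restore whatever the flags were before set_deterministic().
    torch.use_deterministic_algorithms(state['use_deterministic_algorithms'])
    torch.backends.cudnn.deterministic = state['cudnn_deterministic']
    torch.backends.cudnn.benchmark = state['cudnn_benchmarks']


def set_deterministic_seed(seed):
    # Seed every RNG the training path touches: python, numpy, torch CPU and CUDA.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def run_once(start_step, gradient_step=2, micro_batches=4):
    # Toy stand-in for the training loop: reseed per micro-batch exactly as the
    # patch does, with seed = step * gradient_step + j, then draw some noise.
    outputs = []
    for j in range(micro_batches):
        set_deterministic_seed(start_step * gradient_step + j)
        outputs.append(torch.randn(3))  # placeholder for one micro-batch of work
    return torch.stack(outputs)


if __name__ == '__main__':
    old_state = set_deterministic()
    try:
        first = run_once(start_step=10)
        second = run_once(start_step=10)  # simulate resuming from the same step
        print('identical:', torch.equal(first, second))  # prints: identical: True
    finally:
        reset_deterministic(old_state)

Because the per-batch seed is derived only from the global step (step * gradient_step + j), a run that is interrupted and resumed from a saved step replays the same RNG draws as an uninterrupted run with the same settings and dataset, which is what makes the training reproducible end to end.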