Add deterministic training

Shondoit 2023-01-18 19:50:22 +01:00
parent ea9bd9fc74
commit 114a8a2f62
4 changed files with 54 additions and 0 deletions

launch.py

@@ -349,6 +349,11 @@ def tests(test_dir):
 def start():
     print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
+
+    #Do this before Torch is loaded
+    if '--deterministic-training' in sys.argv:
+        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+
     import webui
     if '--nowebui' in sys.argv:
         webui.api_only()
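The environment variable has to be in place before PyTorch creates its cuBLAS handle, which is why the commit sets it in start() ahead of `import webui`. A minimal standalone sketch of the mechanism (assumes a CUDA build of PyTorch; NVIDIA also accepts `:16:8` as a lower-memory value):

import os

# cuBLAS reads this only when its handle is created, so it must be set
# before anything initializes CUDA; ahead of the torch import is safe.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch

torch.use_deterministic_algorithms(True)

if torch.cuda.is_available():
    a = torch.randn(64, 64, device='cuda')
    b = torch.randn(64, 64, device='cuda')
    # Without the workspace config, deterministic mode raises a RuntimeError
    # on cuBLAS-backed ops like this matmul (CUDA >= 10.2).
    print((a @ b).sum().item())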

modules/hypernetworks/hypernetwork.py

@@ -6,9 +6,11 @@ import os
 import sys
 import traceback
 import inspect
+import random
 
 import modules.textual_inversion.dataset
 import torch
+import numpy as np
 import tqdm
 from einops import rearrange, repeat
 from ldm.util import default
@@ -552,6 +554,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
 
+    if shared.cmd_opts.deterministic_training:
+        old_deterministic_state = set_deterministic()
+
     pin_memory = shared.opts.pin_memory
 
     ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
@@ -635,6 +640,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 if shared.state.interrupted:
                     break
 
+                if shared.cmd_opts.deterministic_training:
+                    set_deterministic_seed(hypernetwork.step * gradient_step + j)
+
                 if clip_grad:
                     clip_grad_sched.step(hypernetwork.step)
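The seed hypernetwork.step * gradient_step + j is a pure function of the step counter, so every gradient-accumulation micro-batch gets its own seed and a run resumed from a checkpoint replays the same sequence. A quick worked sketch with illustrative values:

gradient_step = 2  # illustrative accumulation factor
for step in (10, 11):
    for j in range(gradient_step):
        print(step * gradient_step + j)  # prints 20, 21, 22, 23: unique, resume-stable seeds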
@@ -775,6 +783,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
+    if shared.cmd_opts.deterministic_training:
+        reset_deterministic(old_deterministic_state)
+
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
     hypernetwork.optimizer_name = optimizer_name
     if shared.opts.save_optimizer_state:

modules/shared.py

@@ -61,6 +61,7 @@ parser.add_argument("--clip-models-path", type=str, help="Path to directory with
 parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
 parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
+parser.add_argument("--deterministic-training", action='store_true', help="Enable deterministic training of Hypernetworks and Textual Inversion. (Same settings and dataset will give the same result)")
 parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
 parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")

modules/textual_inversion/textual_inversion.py

@@ -2,9 +2,11 @@ import os
 import sys
 import traceback
 import inspect
+import random
 from collections import namedtuple
 
 import torch
+import numpy as np
 import tqdm
 import html
 import datetime
@@ -351,6 +353,32 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
     assert log_directory, "Log directory is empty"
 
 
+def set_deterministic():
+    state = {
+        'use_deterministic_algorithms': torch.are_deterministic_algorithms_enabled(),
+        'cudnn_deterministic': torch.backends.cudnn.deterministic,
+        'cudnn_benchmarks': torch.backends.cudnn.benchmark,
+    }
+
+    torch.use_deterministic_algorithms(True)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    set_deterministic_seed(0)
+
+    return state
+
+
+def reset_deterministic(state):
+    torch.use_deterministic_algorithms(state['use_deterministic_algorithms'])
+    torch.backends.cudnn.deterministic = state['cudnn_deterministic']
+    torch.backends.cudnn.benchmark = state['cudnn_benchmarks']
+
+
+def set_deterministic_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
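Together the three helpers implement a save/flip/restore pattern: capture the process-wide determinism flags, force them on (and seed every RNG) for the duration of training, then hand back whatever state the rest of the app had. A hypothetical driver loop showing the intended usage; the helper names come from the hunk above, everything else is illustrative:

import torch

from modules.textual_inversion.textual_inversion import set_deterministic, set_deterministic_seed, reset_deterministic

def deterministic_training_loop(total_steps=4, gradient_step=2):
    old_state = set_deterministic()  # flip global flags, seed everything with 0
    try:
        for step in range(total_steps):
            for j in range(gradient_step):
                # one seed per micro-batch, derived from the step counter,
                # so a resumed run replays identical random decisions
                set_deterministic_seed(step * gradient_step + j)
                _ = torch.nn.functional.dropout(torch.ones(4), p=0.5, training=True)  # stand-in for a training step
    finally:
        reset_deterministic(old_state)  # restore the caller's previous flags

(The commit itself restores the state at the normal end of train_embedding and train_hypernetwork rather than in a finally block.)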
@@ -408,6 +436,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
     if shared.opts.training_enable_tensorboard:
         tensorboard_writer = tensorboard_setup(log_directory)
 
+    if shared.cmd_opts.deterministic_training:
+        old_deterministic_state = set_deterministic()
+
     pin_memory = shared.opts.pin_memory
 
     ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
@@ -475,6 +506,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 if shared.state.interrupted:
                     break
 
+                if shared.cmd_opts.deterministic_training:
+                    set_deterministic_seed(embedding.step * gradient_step + j)
+
                 if clip_grad:
                     clip_grad_sched.step(embedding.step)
@@ -627,6 +661,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     shared.parallel_processing_allowed = old_parallel_processing_allowed
     sd_hijack_checkpoint.remove()
 
+    if shared.cmd_opts.deterministic_training:
+        reset_deterministic(old_deterministic_state)
+
     return embedding, filename
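A quick end-to-end sanity check of the helpers, assuming the webui modules are importable: seeding twice with the same value should yield bit-identical draws.

import torch

from modules.textual_inversion.textual_inversion import set_deterministic, set_deterministic_seed, reset_deterministic

state = set_deterministic()
set_deterministic_seed(42)
first = torch.randn(8)
set_deterministic_seed(42)
second = torch.randn(8)
reset_deterministic(state)

assert torch.equal(first, second)  # same seed -> bit-identical tensors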