Add deterministic training
This commit is contained in:
parent
ea9bd9fc74
commit
114a8a2f62
|
@ -349,6 +349,11 @@ def tests(test_dir):
|
|||
|
||||
def start():
|
||||
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
|
||||
|
||||
#Do this before Torch is loaded
|
||||
if '--deterministic-training' in sys.argv:
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
||||
|
||||
import webui
|
||||
if '--nowebui' in sys.argv:
|
||||
webui.api_only()
|
||||
|
|
|
@ -6,9 +6,11 @@ import os
|
|||
import sys
|
||||
import traceback
|
||||
import inspect
|
||||
import random
|
||||
|
||||
import modules.textual_inversion.dataset
|
||||
import torch
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
|
@ -552,6 +554,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
|
||||
if shared.cmd_opts.deterministic_training:
|
||||
old_deterministic_state = set_deterministic()
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
|
||||
|
@ -635,6 +640,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
if shared.cmd_opts.deterministic_training:
|
||||
set_deterministic_seed(embedding.step * gradient_step + j)
|
||||
|
||||
if clip_grad:
|
||||
clip_grad_sched.step(hypernetwork.step)
|
||||
|
||||
|
@ -775,6 +783,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
|
||||
|
||||
|
||||
if shared.cmd_opts.deterministic_training:
|
||||
reset_deterministic(old_deterministic_state)
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
if shared.opts.save_optimizer_state:
|
||||
|
|
|
@ -61,6 +61,7 @@ parser.add_argument("--clip-models-path", type=str, help="Path to directory with
|
|||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deterministic-training", action='store_true', help="Enable deterministic training of Hypernetworks and Textual Inversion. (Same settings and dataset will give the same result)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
|
|
|
@ -2,9 +2,11 @@ import os
|
|||
import sys
|
||||
import traceback
|
||||
import inspect
|
||||
import random
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import html
|
||||
import datetime
|
||||
|
@ -351,6 +353,32 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
|||
assert log_directory, "Log directory is empty"
|
||||
|
||||
|
||||
def set_deterministic():
    """Switch torch into deterministic mode and return the previous settings.

    Records the current values of the global deterministic-algorithms flag,
    ``cudnn.deterministic`` and ``cudnn.benchmark``, then enables
    deterministic algorithms, enables cuDNN determinism, disables cuDNN
    benchmarking, and reseeds the random generators with 0 through
    ``set_deterministic_seed``.

    Returns:
        dict: the settings in effect before the call, keyed by
        ``'use_deterministic_algorithms'``, ``'cudnn_deterministic'`` and
        ``'cudnn_benchmarks'``; pass it to ``reset_deterministic`` to undo
        the switch.
    """
    cudnn = torch.backends.cudnn

    # Snapshot first: the values must be read before they are overwritten.
    previous = dict(
        use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
        cudnn_deterministic=cudnn.deterministic,
        cudnn_benchmarks=cudnn.benchmark,
    )

    torch.use_deterministic_algorithms(True)
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Start every deterministic run from the same RNG state.
    set_deterministic_seed(0)

    return previous
|
||||
|
||||
def reset_deterministic(state):
    """Restore the torch determinism settings captured in *state*.

    Args:
        state: mapping with the keys ``'use_deterministic_algorithms'``,
            ``'cudnn_deterministic'`` and ``'cudnn_benchmarks'`` (the shape
            returned by ``set_deterministic``).

    Raises:
        KeyError: if one of the three keys is missing from *state*.
    """
    cudnn = torch.backends.cudnn

    # Applied in the same order the values were captured.
    torch.use_deterministic_algorithms(state['use_deterministic_algorithms'])
    cudnn.deterministic = state['cudnn_deterministic']
    cudnn.benchmark = state['cudnn_benchmarks']
|
||||
|
||||
def set_deterministic_seed(seed):
    """Seed Python's, NumPy's and torch's global random generators.

    Seeds ``random``, ``numpy.random``, the torch default generator and the
    CUDA generators from the same integer so that a training step started
    from a given seed draws the same random numbers every time.

    Args:
        seed: integer seed. ``numpy.random.seed`` only accepts values in
            ``[0, 2**32 - 1]``, so the value handed to NumPy is wrapped into
            that range; ``random`` and torch receive *seed* unchanged.
            Seeds already inside the range behave exactly as before.
    """
    random.seed(seed)
    # NumPy raises ValueError for negative or >= 2**32 seeds, which the other
    # generators accept; wrap instead of failing after `random` was reseeded.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
|
@ -408,6 +436,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
if shared.opts.training_enable_tensorboard:
|
||||
tensorboard_writer = tensorboard_setup(log_directory)
|
||||
|
||||
if shared.cmd_opts.deterministic_training:
|
||||
old_deterministic_state = set_deterministic()
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
|
||||
|
@ -475,6 +506,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
if shared.cmd_opts.deterministic_training:
|
||||
set_deterministic_seed(embedding.step * gradient_step + j)
|
||||
|
||||
if clip_grad:
|
||||
clip_grad_sched.step(embedding.step)
|
||||
|
||||
|
@ -627,6 +661,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
sd_hijack_checkpoint.remove()
|
||||
|
||||
if shared.cmd_opts.deterministic_training:
|
||||
reset_deterministic(old_deterministic_state)
|
||||
|
||||
return embedding, filename
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user