diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index ea3f1db9..300d3975 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -39,7 +39,7 @@ class HypernetworkModule(torch.nn.Module): activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', - add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False): + add_layer_norm=False, activate_output=False, dropout_structure=None): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -64,9 +64,12 @@ class HypernetworkModule(torch.nn.Module): if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - # Add dropout except last layer - if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2): - linears.append(torch.nn.Dropout(p=0.3)) + # Everything should be now parsed into dropout structure, and applied here. + # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0. + if dropout_structure is not None and dropout_structure[i+1] > 0: + assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!" + linears.append(torch.nn.Dropout(p=dropout_structure[i+1])) + # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0]. self.linear = torch.nn.Sequential(*linears) @@ -113,7 +116,7 @@ class HypernetworkModule(torch.nn.Module): state_dict[to] = x def forward(self, x): - return x + self.linear(x) * self.multiplier + return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1) def trainables(self): layer_structure = [] @@ -126,6 +129,21 @@ class HypernetworkModule(torch.nn.Module): def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength +#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check. +def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout): + if layer_structure is None: + layer_structure = [1, 2, 1] + if not use_dropout: + return [0] * len(layer_structure) + dropout_values = [0] + dropout_values.extend([0.3] * (len(layer_structure) - 3)) + if last_layer_dropout: + dropout_values.append(0.3) + else: + dropout_values.append(0) + dropout_values.append(0) + return dropout_values + class Hypernetwork: filename = None @@ -144,18 +162,22 @@ class Hypernetwork: self.add_layer_norm = add_layer_norm self.use_dropout = use_dropout self.activate_output = activate_output - self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True + self.last_layer_dropout = kwargs.get('last_layer_dropout', True) + self.dropout_structure = kwargs.get('dropout_structure', None) + if self.dropout_structure is None: + self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) self.optimizer_name = None self.optimizer_state_dict = None + self.optional_info = None for size in enable_sizes or []: self.layers[size] = ( HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure), ) - self.eval_mode() + self.eval() def weights(self): res = [] @@ -164,14 +186,14 @@ class Hypernetwork: res += layer.parameters() return res - def train_mode(self): + def train(self, mode=True): for k, layers in self.layers.items(): for layer in layers: - layer.train() + layer.train(mode=mode) for param in layer.parameters(): - param.requires_grad = True + param.requires_grad = mode - def eval_mode(self): + def eval(self): for k, layers in self.layers.items(): for layer in layers: layer.eval() @@ -191,11 +213,13 @@ class Hypernetwork: state_dict['activation_func'] = self.activation_func state_dict['is_layer_norm'] = self.add_layer_norm state_dict['weight_initialization'] = self.weight_init - state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['activate_output'] = self.activate_output - state_dict['last_layer_dropout'] = self.last_layer_dropout + state_dict['use_dropout'] = self.use_dropout + state_dict['dropout_structure'] = self.dropout_structure + state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout + state_dict['optional_info'] = self.optional_info if self.optional_info else None if self.optimizer_name is not None: optimizer_saved_dict['optimizer_name'] = self.optimizer_name @@ -215,43 +239,56 @@ class Hypernetwork: self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) print(self.layer_structure) + optional_info = state_dict.get('optional_info', None) + if optional_info is not None: + print(f"INFO:\n {optional_info}\n") + self.optional_info = optional_info self.activation_func = state_dict.get('activation_func', None) print(f"Activation function is {self.activation_func}") self.weight_init = state_dict.get('weight_initialization', 'Normal') print(f"Weight initialization is {self.weight_init}") self.add_layer_norm = state_dict.get('is_layer_norm', False) print(f"Layer norm is set to {self.add_layer_norm}") - self.use_dropout = state_dict.get('use_dropout', False) + self.dropout_structure = state_dict.get('dropout_structure', None) + self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False) print(f"Dropout usage is set to {self.use_dropout}" ) self.activate_output = state_dict.get('activate_output', True) print(f"Activate last layer is set to {self.activate_output}") self.last_layer_dropout = state_dict.get('last_layer_dropout', False) + # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0. + if self.dropout_structure is None: + print("Using previous dropout structure") + self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) + print(f"Dropout structure is set to {self.dropout_structure}") optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} - self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') - print(f"Optimizer name is {self.optimizer_name}") + if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None): self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) else: self.optimizer_state_dict = None if self.optimizer_state_dict: + self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') print("Loaded existing optimizer from checkpoint") + print(f"Optimizer name is {self.optimizer_name}") else: + self.optimizer_name = "AdamW" print("No saved optimizer exists in checkpoint") for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, self.dropout_structure), HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, self.dropout_structure), ) self.name = state_dict.get('name', self.name) self.step = state_dict.get('step', 0) self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) + self.eval() def list_hypernetworks(path): @@ -379,9 +416,10 @@ def report_statistics(loss_info:dict): print(e) -def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) + assert name, "Name cannot be empty!" fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") if not overwrite_old: @@ -390,6 +428,11 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, if type(layer_structure) == str: layer_structure = [float(x.strip()) for x in layer_structure.split(",")] + if use_dropout and dropout_structure and type(dropout_structure) == str: + dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")] + else: + dropout_structure = [0] * len(layer_structure) + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( name=name, enable_sizes=[int(x) for x in enable_sizes], @@ -398,6 +441,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, weight_init=weight_init, add_layer_norm=add_layer_norm, use_dropout=use_dropout, + dropout_structure=dropout_structure ) hypernet.save(fn) @@ -480,7 +524,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, shared.sd_model.first_stage_model.to(devices.cpu) weights = hypernetwork.weights() - hypernetwork.train_mode() + hypernetwork.train() # Here we use optimizer from saved HN, or we can specify as UI option. if hypernetwork.optimizer_name in optimizer_dict: @@ -594,7 +638,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if images_dir is not None and steps_done % create_image_every == 0: forced_filename = f'{hypernetwork_name}-{steps_done}' last_saved_image = os.path.join(images_dir, forced_filename) - hypernetwork.eval_mode() + hypernetwork.eval() + rng_state = torch.get_rng_state() + cuda_rng_state = None + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state_all() shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) @@ -627,7 +675,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - hypernetwork.train_mode() + torch.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state_all(cuda_rng_state) + hypernetwork.train() if image is not None: shared.state.current_image = image last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) @@ -649,7 +700,7 @@ Last saved image: {html.escape(last_saved_image)}
finally: pbar.leave = False pbar.close() - hypernetwork.eval_mode() + hypernetwork.eval() #report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index e7f9e593..81e3f519 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,8 +9,8 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"] keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) -def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): - filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout) +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): + filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", "" diff --git a/modules/ui.py b/modules/ui.py index b6079aec..9b9081b5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1268,6 +1268,7 @@ def create_ui(): new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") with gr.Row(): @@ -1414,7 +1415,8 @@ def create_ui(): new_hypernetwork_activation_func, new_hypernetwork_initialization_option, new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure ], outputs=[ train_hypernetwork_name,