From c2d5b29040132c171bc4d77f1f63da972306f22c Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Thu, 29 Sep 2022 01:14:54 -0300 Subject: [PATCH] Move silu to sd_hijack --- modules/sd_hijack.py | 12 +++--------- webui.py | 3 --- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bfbd07f9..4bc58fa2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -12,6 +12,7 @@ from ldm.util import default from einops import rearrange import ldm.modules.attention import ldm.modules.diffusionmodules.model +from torch.nn.functional import silu # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion @@ -100,14 +101,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -245,11 +238,12 @@ class StableDiffusionModelHijack: m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.clip = m.cond_stage_model + ldm.modules.diffusionmodules.model.nonlinearity = silu + if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward def flatten(el): diff --git a/webui.py b/webui.py index b61a318d..c70a11c7 100644 --- a/webui.py +++ b/webui.py @@ -22,10 +22,7 @@ import modules.txt2img import modules.img2img import modules.swinir as swinir import modules.sd_models -from torch.nn.functional import silu -import ldm -ldm.modules.diffusionmodules.model.nonlinearity = silu modules.codeformer_model.setup_codeformer() modules.gfpgan_model.setup_gfpgan()