Move silu to sd_hijack
parent c938679de7
commit c2d5b29040
@@ -12,6 +12,7 @@ from ldm.util import default
 from einops import rearrange
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
+from torch.nn.functional import silu


 # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
@@ -100,14 +101,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

     return self.to_out(r2)

-def nonlinearity_hijack(x):
-    # swish
-    t = torch.sigmoid(x)
-    x *= t
-    del t
-
-    return x
-
 def cross_attention_attnblock_forward(self, x):
         h_ = x
         h_ = self.norm(h_)
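Note on the removal above: nonlinearity_hijack is an in-place implementation of swish, x * sigmoid(x), which is exactly what torch.nn.functional.silu computes, so the silu import added in the first hunk is a drop-in replacement. A minimal sketch of the equivalence (the helper body is copied from the removed lines; the comparison itself is only illustrative and not part of the commit):

import torch
from torch.nn.functional import silu

def nonlinearity_hijack(x):
    # swish: multiply x by sigmoid(x) in place to avoid an extra temporary
    t = torch.sigmoid(x)
    x *= t
    del t
    return x

x = torch.randn(4, 8)
# both paths produce the same values, which is why the swap is safe
assert torch.allclose(silu(x), nonlinearity_hijack(x.clone()))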
@@ -245,11 +238,12 @@ class StableDiffusionModelHijack:
         m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
         self.clip = m.cond_stage_model

+        ldm.modules.diffusionmodules.model.nonlinearity = silu
+
         if cmd_opts.opt_split_attention_v1:
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
         elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
-            ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
             ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward

 def flatten(el):
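The hunk above also shows why a plain assignment is enough to change the model's activation: code in ldm/modules/diffusionmodules/model.py calls nonlinearity(...) through a module-level global, so the name is resolved at call time and rebinding the module attribute swaps the activation for every block. A rough sketch of that pattern (ResnetBlockSketch and the inline nonlinearity are illustrative stand-ins, not the actual ldm classes):

import torch
from torch.nn.functional import silu

def nonlinearity(x):
    # stand-in for the module-level activation in ldm's model.py
    return x * torch.sigmoid(x)

class ResnetBlockSketch(torch.nn.Module):
    def forward(self, h):
        # looked up in the module's globals at call time, not bound at import,
        # which is what makes the sd_hijack-style rebinding take effect
        return nonlinearity(h)

block = ResnetBlockSketch()
x = torch.randn(2, 4)
out_before = block(x)

# From another file this would be
#   ldm.modules.diffusionmodules.model.nonlinearity = silu
# here the same global is rebound directly.
nonlinearity = silu
out_after = block(x)

assert torch.allclose(out_before, out_after)  # swish and silu agree

With this commit the rebinding happens inside StableDiffusionModelHijack (the hunk above) instead of at webui.py import time, which is what the removal below cleans up.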
webui.py
@@ -22,10 +22,7 @@ import modules.txt2img
 import modules.img2img
 import modules.swinir as swinir
 import modules.sd_models
-from torch.nn.functional import silu
-import ldm

-ldm.modules.diffusionmodules.model.nonlinearity = silu

 modules.codeformer_model.setup_codeformer()
 modules.gfpgan_model.setup_gfpgan()
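With the assignment gone from webui.py, the activation patch is applied when the StableDiffusionModelHijack setup shown above runs. A quick, purely illustrative way to confirm at runtime that the hijack took effect (not code from the repo):

from torch.nn.functional import silu
import ldm.modules.diffusionmodules.model as ldm_model

# once the hijack has run, the diffusion model's nonlinearity should be
# PyTorch's silu rather than the hand-rolled swish
print(ldm_model.nonlinearity is silu)  # expected: True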