Move silu to sd_hijack

This commit is contained in:
Jairo Correa 2022-09-29 01:14:54 -03:00
parent c938679de7
commit c2d5b29040
2 changed files with 3 additions and 12 deletions

View File

@ -12,6 +12,7 @@ from ldm.util import default
from einops import rearrange from einops import rearrange
import ldm.modules.attention import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
from torch.nn.functional import silu
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
@ -100,14 +101,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2) return self.to_out(r2)
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x): def cross_attention_attnblock_forward(self, x):
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)
@ -245,11 +238,12 @@ class StableDiffusionModelHijack:
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.clip = m.cond_stage_model self.clip = m.cond_stage_model
ldm.modules.diffusionmodules.model.nonlinearity = silu
if cmd_opts.opt_split_attention_v1: if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
def flatten(el): def flatten(el):

View File

@ -22,10 +22,7 @@ import modules.txt2img
import modules.img2img import modules.img2img
import modules.swinir as swinir import modules.swinir as swinir
import modules.sd_models import modules.sd_models
from torch.nn.functional import silu
import ldm
ldm.modules.diffusionmodules.model.nonlinearity = silu
modules.codeformer_model.setup_codeformer() modules.codeformer_model.setup_codeformer()
modules.gfpgan_model.setup_gfpgan() modules.gfpgan_model.setup_gfpgan()