Merge pull request #4978 from aliencaocao/support_any_resolution
Patch UNet Forward to support resolutions that are not multiples of 64
This commit is contained in:
commit
2641d1b83b
|
@ -39,6 +39,7 @@ def apply_optimizations():
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward
|
||||||
|
|
||||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||||
print("Applying xformers cross attention optimization.")
|
print("Applying xformers cross attention optimization.")
|
||||||
|
|
|
@ -5,6 +5,7 @@ import importlib
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
@ -12,6 +13,8 @@ from einops import rearrange
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import timestep_embedding
|
||||||
|
|
||||||
|
|
||||||
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
||||||
try:
|
try:
|
||||||
|
@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x):
|
||||||
return x + out
|
return x + out
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
return cross_attention_attnblock_forward(self, x)
|
return cross_attention_attnblock_forward(self, x)
|
||||||
|
|
||||||
|
def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
||||||
|
assert (y is not None) == (
|
||||||
|
self.num_classes is not None
|
||||||
|
), "must specify y if and only if the model is class-conditional"
|
||||||
|
hs = []
|
||||||
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||||
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
|
if self.num_classes is not None:
|
||||||
|
assert y.shape == (x.shape[0],)
|
||||||
|
emb = emb + self.label_emb(y)
|
||||||
|
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
for module in self.input_blocks:
|
||||||
|
h = module(h, emb, context)
|
||||||
|
hs.append(h)
|
||||||
|
h = self.middle_block(h, emb, context)
|
||||||
|
for module in self.output_blocks:
|
||||||
|
if h.shape[-2:] != hs[-1].shape[-2:]:
|
||||||
|
h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
|
||||||
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
|
h = module(h, emb, context)
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
if self.predict_codebook_ids:
|
||||||
|
return self.id_predictor(h)
|
||||||
|
else:
|
||||||
|
return self.out(h)
|
||||||
|
|
|
@ -302,8 +302,8 @@ def create_seed_inputs():
|
||||||
|
|
||||||
with gr.Row(visible=False) as seed_extra_row_2:
|
with gr.Row(visible=False) as seed_extra_row_2:
|
||||||
seed_extras.append(seed_extra_row_2)
|
seed_extras.append(seed_extra_row_2)
|
||||||
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0)
|
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0)
|
||||||
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0)
|
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0)
|
||||||
|
|
||||||
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
|
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
|
||||||
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
|
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
|
||||||
|
@ -635,8 +635,8 @@ def create_ui():
|
||||||
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
|
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
|
||||||
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
|
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
|
||||||
|
@ -644,8 +644,8 @@ def create_ui():
|
||||||
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
|
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
|
||||||
|
|
||||||
with gr.Row(visible=False) as hr_options:
|
with gr.Row(visible=False) as hr_options:
|
||||||
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
|
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0)
|
||||||
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
|
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0)
|
||||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
|
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
|
||||||
|
|
||||||
with gr.Row(equal_height=True):
|
with gr.Row(equal_height=True):
|
||||||
|
@ -835,8 +835,8 @@ def create_ui():
|
||||||
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
|
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
|
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||||
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
|
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
|
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
|
||||||
|
@ -1171,8 +1171,8 @@ def create_ui():
|
||||||
with gr.Tab(label="Preprocess images"):
|
with gr.Tab(label="Preprocess images"):
|
||||||
process_src = gr.Textbox(label='Source directory')
|
process_src = gr.Textbox(label='Source directory')
|
||||||
process_dst = gr.Textbox(label='Destination directory')
|
process_dst = gr.Textbox(label='Destination directory')
|
||||||
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
|
||||||
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
|
||||||
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
|
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -1230,8 +1230,8 @@ def create_ui():
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||||
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
|
||||||
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
|
||||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user