Merge pull request #5065 from JaySmithWpg/vram-leak

#3449 - VRAM leak when switching to/from inpainting checkpoint

commit 01f2ed6844
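Judging from the diff below, the leak came from the old TorchHijack holding a reference to the whole KDiffusionSampler while itself being assigned to the module-level k_diffusion.sampling.torch. That global outlives any single sampler, so the previous sampler (and, through it, the previous checkpoint's tensors) stayed reachable after a model switch. The fix passes only the precomputed noise tensors into the hijack. A minimal, self-contained sketch of that pattern; all names here are illustrative stand-ins, not the webui code:

import gc
import weakref

class Sampler:
    # Stand-in for KDiffusionSampler; `model` plays the checkpoint tensors.
    def __init__(self):
        self.model = object()

class LeakyHijack:
    def __init__(self, sampler):   # old pattern: hold the whole sampler
        self.sampler = sampler

class FixedHijack:
    def __init__(self, noises):    # new pattern: hold only the noise tensors
        self.noises = list(noises)

module_global = None               # plays the role of k_diffusion.sampling.torch

s = Sampler()
alive = weakref.ref(s)
module_global = LeakyHijack(s)
del s
gc.collect()
print(alive() is not None)  # True: the global still pins the sampler and its model

s = Sampler()
alive = weakref.ref(s)
module_global = FixedHijack([])
del s
gc.collect()
print(alive() is not None)  # False: nothing pins the sampler; it can be freed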
@@ -1,4 +1,4 @@
-from collections import namedtuple
+from collections import namedtuple, deque
 import numpy as np
 from math import floor
 import torch
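The new deque import supports the rewritten noise queue below: noises are consumed front-to-back with popleft(), which is O(1) on a deque, while consuming from the front of a plain list costs O(n) per removal. An illustrative micro-comparison, not part of the PR:

from collections import deque
from timeit import timeit

N = 100_000

def drain_list():
    xs = list(range(N))
    while xs:
        xs.pop(0)          # shifts every remaining element: O(n) per pop

def drain_deque():
    q = deque(range(N))
    while q:
        q.popleft()        # O(1) per pop

print(timeit(drain_list, number=1))   # noticeably slower
print(timeit(drain_deque, number=1))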
@@ -344,18 +344,28 @@ class CFGDenoiser(torch.nn.Module):
 
 
 class TorchHijack:
-    def __init__(self, kdiff_sampler):
-        self.kdiff_sampler = kdiff_sampler
+    def __init__(self, sampler_noises):
+        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+        # implementation.
+        self.sampler_noises = deque(sampler_noises)
 
     def __getattr__(self, item):
         if item == 'randn_like':
-            return self.kdiff_sampler.randn_like
+            return self.randn_like
 
         if hasattr(torch, item):
             return getattr(torch, item)
 
         raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
 
+    def randn_like(self, x):
+        if self.sampler_noises:
+            noise = self.sampler_noises.popleft()
+            if noise.shape == x.shape:
+                return noise
+
+        return torch.randn_like(x)
+
 
 class KDiffusionSampler:
     def __init__(self, funcname, sd_model):
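TorchHijack acts as a drop-in stand-in for the torch module: __getattr__ forwards anything it doesn't override to the real torch, and randn_like serves queued noises until the queue runs out or a shape mismatch occurs. Below is a runnable copy of the new class from the hunk above, followed by an assumed toy usage:

import torch
from collections import deque

class TorchHijack:
    def __init__(self, sampler_noises):
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        if item == 'randn_like':
            return self.randn_like

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

    def randn_like(self, x):
        if self.sampler_noises:
            noise = self.sampler_noises.popleft()
            if noise.shape == x.shape:
                return noise

        return torch.randn_like(x)

hijack = TorchHijack([torch.ones(2, 2)])
x = torch.zeros(2, 2)
print(hijack.randn_like(x))  # queued noise, consumed FIFO: all ones
print(hijack.randn_like(x))  # queue exhausted: falls back to torch.randn_like
print(hijack.float32)        # any other attribute is forwarded to the real torch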
@@ -367,7 +377,6 @@ class KDiffusionSampler:
         self.extra_params = sampler_extra_params.get(funcname, [])
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
         self.sampler_noises = None
-        self.sampler_noise_index = 0
         self.stop_at = None
         self.eta = None
         self.default_eta = 1.0
@@ -400,26 +409,14 @@ class KDiffusionSampler:
     def number_of_needed_noises(self, p):
        return p.steps
 
-    def randn_like(self, x):
-        noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
-
-        if noise is not None and x.shape == noise.shape:
-            res = noise
-        else:
-            res = torch.randn_like(x)
-
-        self.sampler_noise_index += 1
-        return res
-
     def initialize(self, p):
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
         self.model_wrap.step = 0
-        self.sampler_noise_index = 0
         self.eta = p.eta or opts.eta_ancestral
 
         if self.sampler_noises is not None:
-            k_diffusion.sampling.torch = TorchHijack(self)
+            k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)
 
         extra_params_kwargs = {}
         for param_name in self.extra_params:
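The assignment to k_diffusion.sampling.torch works because Python resolves a function's globals at call time: any code inside k_diffusion.sampling that refers to the name torch picks up whatever object that module attribute is currently bound to. A minimal sketch of the same monkey-patching idea, using a hypothetical module built on the fly (names are illustrative):

import types
import torch

# Hypothetical stand-in for the k_diffusion.sampling module: its function
# looks up `torch` in its own module globals on every call.
sampling = types.ModuleType("sampling")
sampling.torch = torch
exec("def make_noise(x): return torch.randn_like(x)", sampling.__dict__)

class OnesHijack:
    # Plays the role of TorchHijack: overrides randn_like, forwards the rest.
    def randn_like(self, x):
        return torch.ones_like(x)

    def __getattr__(self, item):
        return getattr(torch, item)

x = torch.zeros(2)
print(sampling.make_noise(x))  # fresh random noise via the real torch
sampling.torch = OnesHijack()  # rebind the module-level name
print(sampling.make_noise(x))  # tensor([1., 1.]): the hijack is now live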