forked from mrq/DL-Art-School
Update diffusion_noise_surfer to support audio
This commit is contained in:
parent
3e073cff85
commit
92e7e57f81
@@ -8,11 +8,13 @@ from collections import OrderedDict
 import numpy
 from PIL import Image
+from scipy.io import wavfile
 from torchvision.transforms import ToTensor

 import utils
 import utils.options as option
 import utils.util as util
+from data.audio.wavfile_dataset import load_audio_from_wav
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from data import create_dataset, create_dataloader
 from tqdm import tqdm

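The new imports pull in scipy's WAV writer and a project-local loader. The loader itself lives in data/audio/wavfile_dataset.py and is not part of this diff; below is a minimal sketch of what a function like load_audio_from_wav plausibly returns. Everything in it is an assumption (resampling omitted), not the real implementation.

    import torch
    from scipy.io import wavfile

    def load_audio_sketch(path, target_sample_rate):
        # Sketch only: the real load_audio_from_wav is defined in data/audio/wavfile_dataset.py.
        sample_rate, data = wavfile.read(path)       # int16/float ndarray, shape (T,) or (T, C)
        assert sample_rate == target_sample_rate     # resampling is omitted in this sketch
        audio = torch.tensor(data, dtype=torch.float)
        if audio.dim() == 2:                         # collapse multi-channel to mono
            audio = audio.mean(dim=1)
        if audio.abs().max() > 1:                    # scale int16 PCM into [-1, 1]
            audio = audio / 32768.0
        return audio.unsqueeze(0)                    # (1, T): channel-first, like an image tensor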
@@ -21,7 +23,7 @@ import numpy as np

 # A rough copy of test.py that "surfs" along a set of random noise priors to show the effect of gaussian noise on the results.

-def forward_pass(model, data, output_dir, spacing):
+def forward_pass(model, data, output_dir, spacing, audio_mode):
     with torch.no_grad():
         model.feed_data(data, 0)
         model.test()

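The extra audio_mode flag is only threaded through to the save step in the next hunk; a hypothetical call with the new signature would look like:

    forward_pass(model, data, results_dir, spacing=0, audio_mode=True)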
@@ -29,13 +31,17 @@ def forward_pass(model, data, output_dir, spacing):
     visuals = model.get_current_visuals()['rlt'].cpu()
     img_path = data['GT_path'][0]
     img_name = osp.splitext(osp.basename(img_path))[0]
-    sr_img = util.tensor2img(visuals[0])  # uint8
+    sr_img = visuals[0]

     # save images
     suffixes = [f'_{int(spacing)}']
     for suffix in suffixes:
-        save_img_path = osp.join(output_dir, img_name + suffix + '.png')
-        util.save_img(sr_img, save_img_path)
+        if audio_mode:
+            save_img_path = osp.join(output_dir, img_name + suffix + '.wav')
+            wavfile.write(osp.join(output_dir, save_img_path), 22050, sr_img[0].cpu().numpy())
+        else:
+            save_img_path = osp.join(output_dir, img_name + suffix + '.png')
+            util.save_img(util.tensor2img(sr_img), save_img_path)


 if __name__ == "__main__":
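Two details of the new audio branch are worth flagging. wavfile.write(filename, rate, data) writes a float32 array as a 32-bit-float WAV at the hard-coded 22050 Hz, which is why sr_img now stays a tensor instead of being converted to uint8. Also, save_img_path is already joined with output_dir, so the second osp.join repeats it; os.path.join only discards the duplicate prefix when save_img_path is absolute (i.e. when results_root is an absolute path). A standalone illustration of the write, with a random (1, T) tensor standing in for visuals[0]:

    import torch
    from scipy.io import wavfile

    sr_img = torch.rand(1, 40960) * 2 - 1                          # stand-in for visuals[0], values in [-1, 1]
    wavfile.write('sample_0.wav', 22050, sr_img[0].cpu().numpy())  # float32 data -> 32-bit float WAV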
@@ -45,10 +51,11 @@ if __name__ == "__main__":
     np.random.seed(5555)

     #### options
+    audio_mode = True  # Whether to render audio or images.
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_unet.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
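The script now reads a handful of keys from the options file. A hypothetical skeleton of the parsed opt, restricted to the keys this script actually dereferences; the values are placeholders, not taken from the real test_diffusion_vocoder.yml:

    opt = {
        'image': '/data/samples/clip.wav',         # input clip in audio_mode (an image file otherwise)
        'sample_rate': 22050,                      # forwarded to load_audio_from_wav
        'correction_factor': [0.0],                # only used on the image path
        'path': {'results_root': '/tmp/results'},  # root of the output directory
    }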
@@ -62,33 +69,45 @@ if __name__ == "__main__":
     logger.info(option.dict2str(opt))

     # Load test image
-    im = ToTensor()(Image.open(opt['image'])) * 2 - 1
-    _, h, w = im.shape
-    if h % 2 == 1:
-        im = im[:,1:,:]
-        h = h-1
-    if w % 2 == 1:
-        im = im[:,:,1:]
-        w = w-1
-    dh, dw = (h - 32 * (h // 32)) // 2, (w - 32 * (w // 32)) // 2
-    if dh > 0:
-        im = im[:,dh:-dh]
-    if dw > 0:
-        im = im[:,:,dw:-dw]
-    im = im[:3].unsqueeze(0)
+    if audio_mode:
+        im = load_audio_from_wav(opt['image'], opt['sample_rate'])
+        im = im[:, :(im.shape[1]//4096)*4096]
+        # Hack to reduce memory usage (but cuts off sample):
+        im = im[:, :40960]
+    else:
+        im = ToTensor()(Image.open(opt['image'])) * 2 - 1
+        _, h, w = im.shape
+        if h % 2 == 1:
+            im = im[:,1:,:]
+            h = h-1
+        if w % 2 == 1:
+            im = im[:,:,1:]
+            w = w-1
+        dh, dw = (h - 32 * (h // 32)) // 2, (w - 32 * (w // 32)) // 2
+        if dh > 0:
+            im = im[:,dh:-dh]
+        if dw > 0:
+            im = im[:,:,dw:-dw]
+        im = im[:3].unsqueeze(0)

     # Build the corruption indexes we are going to use.
     correction_factors = opt['correction_factor']

-    opt['steps']['generator']['injectors']['visual_debug']['zero_noise'] = False
+    #opt['steps']['generator']['injectors']['visual_debug']['zero_noise'] = False
     model = ExtensibleTrainer(opt)
     results_dir = osp.join(opt['path']['results_root'], os.path.basename(opt['image']))
     util.mkdir(results_dir)
     for i in range(10):
-        data = {
-            'hq': im.to('cuda'),
-            'corruption_entropy': torch.tensor([correction_factors], device='cuda',
-                                               dtype=torch.float),
-            'GT_path': opt['image']
-        }
-        forward_pass(model, data, results_dir, i)
+        if audio_mode:
+            data = {
+                'clip': im.to('cuda'),
+                'GT_path': opt['image']
+            }
+        else:
+            data = {
+                'hq': im.to('cuda'),
+                'corruption_entropy': torch.tensor([correction_factors], device='cuda',
+                                                   dtype=torch.float),
+                'GT_path': opt['image']
+            }
+        forward_pass(model, data, results_dir, i, audio_mode)
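The trim to a multiple of 4096 presumably exists because the diffusion U-Net halves temporal resolution several times, so the input length must divide evenly by the model's total stride (4096 = 2**12); it mirrors the crop-to-multiple-of-32 logic on the image path. The further cut to 40960 samples (about 1.86 s at 22050 Hz) is, as the in-code comment says, purely a memory hack. A generic sketch of the trim, with the stride as an assumed parameter:

    def trim_to_multiple(x, multiple=4096):
        # Keep only the leading samples that fill whole multiples of `multiple`.
        usable = (x.shape[-1] // multiple) * multiple
        return x[..., :usable]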
@@ -97,14 +97,21 @@ class GaussianDiffusionInferenceInjector(Injector):
         model_inputs = {k: state[v][:self.output_batch_size] for k, v in self.model_input_keys.items()}
         gen.eval()
         with torch.no_grad():
-            output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
-                            model_inputs['low_res'].shape[-1] * self.output_scale_factor)
+            if 'low_res' in model_inputs.keys():
+                output_shape = (self.output_batch_size, 3, model_inputs['low_res'].shape[-2] * self.output_scale_factor,
+                                model_inputs['low_res'].shape[-1] * self.output_scale_factor)
+                dev = model_inputs['low_res'].device
+            elif 'spectrogram' in model_inputs.keys():
+                output_shape = (self.output_batch_size, 1, model_inputs['spectrogram'].shape[-1]*256)
+                dev = model_inputs['spectrogram'].device
+            else:
+                raise NotImplementedError
             noise = None
             if self.noise_style == 'zero':
-                noise = torch.zeros(output_shape, device=model_inputs['low_res'].device)
+                noise = torch.zeros(output_shape, device=dev)
             elif self.noise_style == 'fixed':
                 if not hasattr(self, 'fixed_noise') or self.fixed_noise.shape != output_shape:
-                    self.fixed_noise = torch.randn(output_shape, device=model_inputs['low_res'].device)
+                    self.fixed_noise = torch.randn(output_shape, device=dev)
                 noise = self.fixed_noise
             gen = self.sampling_fn(gen, output_shape, noise=noise, model_kwargs=model_inputs, progress=True)
             if self.undo_n1_to_1:
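The spectrogram branch sizes its output as frames * 256, which implies the vocoder's spectrogram frames are assumed to be spaced 256 samples apart (the hop length); the constant is hard-coded here rather than read from the options. The arithmetic it encodes, with a hypothetical frame count:

    frames = 160                        # hypothetical spectrogram length (shape[-1])
    hop_length = 256                    # implied by model_inputs['spectrogram'].shape[-1] * 256
    waveform_len = frames * hop_length  # 40960 samples, ~1.86 s at 22050 Hz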