Undo baseline GDI changes

This commit is contained in:
James Betker 2021-11-18 20:02:09 -07:00
parent 1287915f3c
commit c30a38cdf1
2 changed files with 6 additions and 4 deletions

View File

@@ -40,7 +40,7 @@ def forward_pass(model, data, output_dir, spacing, audio_mode):
for suffix in suffixes:
if audio_mode:
save_img_path = osp.join(output_dir, img_name + suffix + '.wav')
wavfile.write(osp.join(output_dir, save_img_path), 22050, sr_img[0].cpu().numpy())
wavfile.write(osp.join(output_dir, save_img_path), 11025, sr_img[0].cpu().numpy())
else:
save_img_path = osp.join(output_dir, img_name + suffix + '.png')
util.save_img(util.tensor2img(sr_img), save_img_path)
@@ -50,7 +50,9 @@ def load_image(path, audio_mode):
# Load test image
if audio_mode:
im = load_audio(path, 22050)
im = im[:, :(im.shape[1]//4096)*4096].unsqueeze(0)
padding_needed = ((im.shape[1]//8192)+1)*8192-im.shape[1]
im = torch.nn.functional.pad(im, (0, padding_needed))
im = im[:, :(im.shape[1]//8192)*8192].unsqueeze(0)
else:
im = ToTensor()(Image.open(path)) * 2 - 1
_, h, w = im.shape
@@ -80,7 +82,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder_10-28.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_diffusion_vocoder_dvae.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt

View File

@@ -109,7 +109,7 @@ class GaussianDiffusionInferenceInjector(Injector):
model_inputs['low_res'].shape[-1] * self.output_scale_factor)
dev = model_inputs['low_res'].device
elif 'spectrogram' in model_inputs.keys():
output_shape = (self.output_batch_size, 1, model_inputs['spectrogram'].shape[-1]*256)
output_shape = (self.output_batch_size, 1, model_inputs['spectrogram'].shape[-1] * self.output_scale_factor)
dev = model_inputs['spectrogram'].device
elif 'discrete_spectrogram' in model_inputs.keys():
output_shape = (self.output_batch_size, 1, model_inputs['discrete_spectrogram'].shape[-1]*1024)