This commit is contained in:
James Betker 2022-07-19 11:14:20 -06:00
parent da9e47ca0e
commit 6b1cfe8e66

View File

@ -174,7 +174,7 @@ class TransformerDiffusion(nn.Module):
assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}'
resolution = randrange(1, self.resolution_steps)
resolution_scale = 2 ** resolution
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest', align_corners=True)
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
s_diff = s.shape[-1] - self.max_window
if s_diff > 1:
@ -265,5 +265,5 @@ def remove_conditioning(sd_path):
if __name__ == '__main__':
remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator.pth')
remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr_pre\\models\\12500_generator.pth')
test_tfd()