This commit is contained in:
James Betker 2021-09-23 16:07:58 -06:00
parent 6833048bf7
commit e24c619387
2 changed files with 5 additions and 11 deletions

View File

@ -251,11 +251,7 @@ class DiffusionDVAE(nn.Module):
)
def get_debug_values(self, step, __):
if self.record_codes:
# Report annealing schedule
return {'histogram_codes': self.codes}
else:
return {}
return {'histogram_codes': self.codes}
@torch.no_grad()
@eval_decorator
@ -379,10 +375,8 @@ class DiffusionDVAE(nn.Module):
# Test for ~4 second audio clip at 22050Hz
if __name__ == '__main__':
spec = torch.randn(4, 80, 416)
cond = torch.randn(4, 5, 80, 200)
num_cond = torch.tensor([2,4,5,3], dtype=torch.long)
spec = torch.randn(4, 80, 161)
ts = torch.LongTensor([432, 234, 100, 555])
model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2],
channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False)
print(model(torch.randn_like(spec), ts, spec, cond, num_cond)[0].shape)
print(model(torch.randn_like(spec), ts, spec)[0].shape)

View File

@ -565,5 +565,5 @@ class RandomAudioCropInjector(Injector):
if __name__ == '__main__':
inj = DecomposeDimensionInjector({'dim':2, 'in': 'x', 'out': 'y'}, None)
print(inj({'x':torch.randn(10,3,64,64)})['y'].shape)
inj = MelSpectrogramInjector({'in': 'x', 'out': 'y'}, None)
print(inj({'x':torch.rand(10,1,40800)})['y'].shape)