Fix
This commit is contained in:
parent
6833048bf7
commit
e24c619387
|
@ -251,11 +251,7 @@ class DiffusionDVAE(nn.Module):
|
|||
)
|
||||
|
||||
def get_debug_values(self, step, __):
|
||||
if self.record_codes:
|
||||
# Report annealing schedule
|
||||
return {'histogram_codes': self.codes}
|
||||
else:
|
||||
return {}
|
||||
return {'histogram_codes': self.codes}
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
|
@ -379,10 +375,8 @@ class DiffusionDVAE(nn.Module):
|
|||
|
||||
# Test for ~4 second audio clip at 22050Hz
|
||||
if __name__ == '__main__':
|
||||
spec = torch.randn(4, 80, 416)
|
||||
cond = torch.randn(4, 5, 80, 200)
|
||||
num_cond = torch.tensor([2,4,5,3], dtype=torch.long)
|
||||
spec = torch.randn(4, 80, 161)
|
||||
ts = torch.LongTensor([432, 234, 100, 555])
|
||||
model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2],
|
||||
channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False)
|
||||
print(model(torch.randn_like(spec), ts, spec, cond, num_cond)[0].shape)
|
||||
print(model(torch.randn_like(spec), ts, spec)[0].shape)
|
||||
|
|
|
@ -565,5 +565,5 @@ class RandomAudioCropInjector(Injector):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
inj = DecomposeDimensionInjector({'dim':2, 'in': 'x', 'out': 'y'}, None)
|
||||
print(inj({'x':torch.randn(10,3,64,64)})['y'].shape)
|
||||
inj = MelSpectrogramInjector({'in': 'x', 'out': 'y'}, None)
|
||||
print(inj({'x':torch.rand(10,1,40800)})['y'].shape)
|
Loading…
Reference in New Issue
Block a user