forked from mrq/DL-Art-School
Rework conditioning inputs provided
This commit is contained in:
parent
21b6daa0ed
commit
3a9d1c53ea
|
@ -283,7 +283,7 @@ class DiffusionVocoderWithRef(nn.Module):
|
||||||
self.middle_block.apply(convert_module_to_f32)
|
self.middle_block.apply(convert_module_to_f32)
|
||||||
self.output_blocks.apply(convert_module_to_f32)
|
self.output_blocks.apply(convert_module_to_f32)
|
||||||
|
|
||||||
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None):
|
def forward(self, x, timesteps, spectrogram, conditioning_input=None):
|
||||||
"""
|
"""
|
||||||
Apply the model to an input batch.
|
Apply the model to an input batch.
|
||||||
|
|
||||||
|
@ -294,14 +294,12 @@ class DiffusionVocoderWithRef(nn.Module):
|
||||||
"""
|
"""
|
||||||
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, thus this requirement.
|
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, thus this requirement.
|
||||||
if self.conditioning_enabled:
|
if self.conditioning_enabled:
|
||||||
assert conditioning_inputs is not None
|
assert conditioning_input is not None
|
||||||
assert num_conditioning_signals is not None
|
|
||||||
|
|
||||||
hs = []
|
hs = []
|
||||||
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||||
if self.conditioning_enabled:
|
if self.conditioning_enabled:
|
||||||
#emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
|
emb2 = self.contextual_embedder(conditioning_input)
|
||||||
emb2 = self.contextual_embedder(conditioning_inputs[:, 0])
|
|
||||||
emb = emb1 + emb2
|
emb = emb1 + emb2
|
||||||
else:
|
else:
|
||||||
emb = emb1
|
emb = emb1
|
||||||
|
@ -331,7 +329,7 @@ if __name__ == '__main__':
|
||||||
clip = torch.randn(2, 1, 40960)
|
clip = torch.randn(2, 1, 40960)
|
||||||
#spec = torch.randint(8192, (2, 40,))
|
#spec = torch.randint(8192, (2, 40,))
|
||||||
spec = torch.randn(2,512,160)
|
spec = torch.randn(2,512,160)
|
||||||
cond = torch.randn(2, 3, 80, 173)
|
cond = torch.randn(2, 80, 173)
|
||||||
ts = torch.LongTensor([555, 556])
|
ts = torch.LongTensor([555, 556])
|
||||||
model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
|
model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
|
||||||
print(model(clip, ts, spec, cond, 3).shape)
|
print(model(clip, ts, spec, cond).shape)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user