forked from mrq/DL-Art-School
Add time_embed_dim_multiplier
This commit is contained in:
parent ba6e46c02a
commit 43e389aac6
@@ -90,6 +90,7 @@ class DiffusionVocoderWithRef(nn.Module):
                 scale_factor=2,
                 conditioning_inputs_provided=True,
                 conditioning_input_dim=80,
+                time_embed_dim_multiplier=4,
                 ):
        super().__init__()

@@ -111,7 +112,7 @@ class DiffusionVocoderWithRef(nn.Module):

        padding = 1 if kernel_size == 3 else 2

-       time_embed_dim = model_channels * 4
+       time_embed_dim = model_channels * time_embed_dim_multiplier
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
@@ -332,5 +333,5 @@ if __name__ == '__main__':
    spec = torch.randn(2,512,160)
    cond = torch.randn(2, 3, 80, 173)
    ts = torch.LongTensor([555, 556])
-   model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True)
+   model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
    print(model(clip, ts, spec, cond, 3).shape)
|
Loading…
Reference in New Issue
Block a user