James Betker 2021-09-17 09:15:36 -06:00
parent a6544f1684
commit 5c8d266d4f
2 changed files with 8 additions and 2 deletions

View File

@@ -122,6 +122,7 @@ class DiffusionDVAE(nn.Module):
linear(time_embed_dim, time_embed_dim),
)
self.conditioning_enabled = conditioning_inputs_provided
if conditioning_inputs_provided:
self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim)
self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim)
@@ -249,6 +250,9 @@ class DiffusionDVAE(nn.Module):
)
def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals):
if self.conditioning_enabled:
assert conditioning_inputs is not None
spec_hs = self.decoder(embeddings)[::-1]
# Shape the spectrogram correctly. There is no guarantee it fits (though I probably should add an assertion here to make sure the resizing isn't too wacky.)
spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
@@ -257,7 +261,7 @@ class DiffusionDVAE(nn.Module):
# Timestep embeddings and conditioning signals are combined using a small transformer.
hs = []
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if conditioning_inputs is not None:
if self.conditioning_enabled:
mask = get_mask_from_lengths(num_conditioning_signals+1, conditioning_inputs.shape[1]+1) # +1 to account for the timestep embeddings we'll add.
emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)

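This file's change gates conditioning on a constructor flag (self.conditioning_enabled) instead of a runtime check on whether conditioning_inputs was passed, and adds an assertion so a conditioning-enabled model must always receive conditioning inputs. A minimal sketch of the pattern follows; ToyDVAE, the Linear layers, and the tensor sizes are illustrative placeholders, not the actual DiffusionDVAE modules.

    import torch
    import torch.nn as nn

    class ToyDVAE(nn.Module):
        def __init__(self, model_channels=64, cond_dim=32, conditioning_inputs_provided=True):
            super().__init__()
            self.time_embed = nn.Linear(model_channels, model_channels)
            # Record the flag at construction time so the decode path keys off it
            # rather than re-checking whether conditioning tensors were passed.
            self.conditioning_enabled = conditioning_inputs_provided
            if conditioning_inputs_provided:
                self.contextual_embedder = nn.Linear(cond_dim, model_channels)

        def forward(self, timestep_emb, conditioning_inputs=None):
            emb = self.time_embed(timestep_emb)
            if self.conditioning_enabled:
                # The commit adds this assertion: a model built with conditioning
                # must always be fed conditioning inputs, so the contextual
                # embedder participates in every forward pass.
                assert conditioning_inputs is not None
                emb = emb + self.contextual_embedder(conditioning_inputs)
            return emb

    if __name__ == "__main__":
        m = ToyDVAE()
        print(m(torch.randn(2, 64), torch.randn(2, 32)).shape)  # torch.Size([2, 64])

Gating on the flag appears to pair with the trainer change below: if the conditioning branch always runs when its parameters exist, DDP no longer needs find_unused_parameters.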
View File

@@ -117,7 +117,9 @@ class ExtensibleTrainer(BaseModel):
dnet = DistributedDataParallel(anet, delay_allreduce=True)
else:
from torch.nn.parallel.distributed import DistributedDataParallel
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
# Do NOT be tempted to put find_unused_parameters=True here. It will not work in the current incarnation of this trainer.
# Use all of your parameters in training, or delete them!
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
else:
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
if self.is_train:
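For reference, a minimal sketch of wrapping a network for distributed training under this constraint, i.e. without find_unused_parameters, as the new comment requires. The process-group setup, the toy model, and the torchrun assumption are illustrative placeholders, not ExtensibleTrainer's actual initialization.

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel.distributed import DistributedDataParallel

    def wrap_for_ddp(anet: nn.Module) -> nn.Module:
        # Every parameter of anet must receive a gradient each step; otherwise the
        # gradient allreduce can hang, because find_unused_parameters is left off.
        return DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])

    if __name__ == "__main__":
        # Assumes a launcher such as torchrun has set RANK/WORLD_SIZE/MASTER_ADDR.
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
        net = nn.Linear(16, 16).cuda()
        dnet = wrap_for_ddp(net)
        out = dnet(torch.randn(4, 16, device="cuda"))
        out.sum().backward()  # all parameters are used, so the backward allreduce completes
        dist.destroy_process_group()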