forked from mrq/DL-Art-School
chk
This commit is contained in:
parent
a6544f1684
commit
5c8d266d4f
|
@ -122,6 +122,7 @@ class DiffusionDVAE(nn.Module):
|
|||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.conditioning_enabled = conditioning_inputs_provided
|
||||
if conditioning_inputs_provided:
|
||||
self.contextual_embedder = AudioMiniEncoder(self.spectrogram_channels, time_embed_dim)
|
||||
self.query_gen = AudioMiniEncoder(decoder_channels[0], time_embed_dim)
|
||||
|
@ -249,6 +250,9 @@ class DiffusionDVAE(nn.Module):
|
|||
)
|
||||
|
||||
def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals):
|
||||
if self.conditioning_enabled:
|
||||
assert conditioning_inputs is not None
|
||||
|
||||
spec_hs = self.decoder(embeddings)[::-1]
|
||||
# Resize each spectrogram feature map to the length expected at its conditioning level. There is no guarantee the input already fits. TODO: add an assertion here to make sure the resizing isn't too extreme.
|
||||
spec_hs = [nn.functional.interpolate(sh, size=(x.shape[-1]//self.scale_steps**self.spectrogram_conditioning_levels[i],), mode='nearest') for i, sh in enumerate(spec_hs)]
|
||||
|
@ -257,7 +261,7 @@ class DiffusionDVAE(nn.Module):
|
|||
# Timestep embeddings and conditioning signals are combined using a small transformer.
|
||||
hs = []
|
||||
emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
if conditioning_inputs is not None:
|
||||
if self.conditioning_enabled:
|
||||
mask = get_mask_from_lengths(num_conditioning_signals+1, conditioning_inputs.shape[1]+1) # +1 to account for the timestep embeddings we'll add.
|
||||
emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
|
||||
emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
|
||||
|
|
|
@ -117,7 +117,9 @@ class ExtensibleTrainer(BaseModel):
|
|||
dnet = DistributedDataParallel(anet, delay_allreduce=True)
|
||||
else:
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
|
||||
# Do NOT be tempted to put find_unused_parameters=True here. It will not work in the current incarnation of this trainer.
|
||||
# Use all of your parameters in training, or delete them!
|
||||
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
|
||||
else:
|
||||
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
|
||||
if self.is_train:
|
||||
|
|
Loading…
Reference in New Issue
Block a user