Add norming to discretization_loss

This commit is contained in:
James Betker 2021-10-06 17:10:50 -06:00
parent bb891a3a53
commit 33120cb35c
4 changed files with 36 additions and 11 deletions

View File

@ -10,21 +10,46 @@ import torch.nn.functional as F
# In other words, attempts to force the discretization function to have a mean equal utilization across all discrete
# values with the specified expected variance.
class DiscretizationLoss(nn.Module):
def __init__(self, dim, expected_variance):
def __init__(self, discrete_bins, dim, expected_variance, store_past=0):
super().__init__()
self.discrete_bins = discrete_bins
self.dim = dim
self.dist = torch.distributions.Normal(0, scale=expected_variance)
if store_past > 0:
self.record_past = True
self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins))
else:
self.record_past = False
def forward(self, x):
other_dims = set(range(len(x.shape)))-set([self.dim])
averaged = x.sum(dim=tuple(other_dims)) / x.sum()
averaged = averaged - averaged.mean()
if self.record_past:
acc_count = self.accumulator.shape[0]
avg = averaged.detach().clone()
if self.accumulator_filled > 0:
averaged = torch.mean(self.accumulator, dim=0) * (acc_count-1) / acc_count + \
averaged / acc_count
# Also push averaged into the accumulator.
self.accumulator[self.accumulator_index] = avg
self.accumulator_index += 1
if self.accumulator_index >= acc_count:
self.accumulator_index *= 0
if self.accumulator_filled <= 0:
self.accumulator_filled += 1
return torch.sum(-self.dist.log_prob(averaged))
if __name__ == '__main__':
d = DiscretizationLoss(1, 1e-6)
v = torch.randn(16, 8192, 500)
d = DiscretizationLoss(1024, 1, 1e-6, store_past=20)
for _ in range(500):
v = torch.randn(16, 1024, 500)
#for k in range(5):
# v[:, random.randint(0,8192), :] += random.random()*100
v = F.softmax(v, 1)

View File

@ -75,6 +75,7 @@ class DiscreteVAE(nn.Module):
straight_through = False,
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
record_codes = False,
discretization_loss_averaging_steps = 100,
):
super().__init__()
assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
@ -85,7 +86,7 @@ class DiscreteVAE(nn.Module):
self.straight_through = straight_through
self.codebook = Quantize(codebook_dim, num_tokens)
self.positional_dims = positional_dims
self.discrete_loss = DiscretizationLoss(2, 1 / (num_tokens*2))
self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
if positional_dims == 2:
@ -224,7 +225,6 @@ class DiscreteVAE(nn.Module):
# discretization loss
disc_loss = self.discrete_loss(soft_codes)
# This is so we can debug the distribution of codes being learned.
if self.record_codes and self.internal_step % 50 == 0:
codes = codes.flatten()

View File

@ -106,7 +106,7 @@ class Quantize(nn.Module):
quantize = input + (quantize - input).detach()
if return_soft_codes:
return quantize, diff, embed_ind, soft_codes.view(input.shape)
return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
else:
return quantize, diff, embed_ind

View File

@ -284,7 +284,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_dvae_clips.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_dvae_clips_with_discretization_loss.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()