diff --git a/codes/models/gpt_voice/dvae_arch_playground/discretization_loss.py b/codes/models/gpt_voice/dvae_arch_playground/discretization_loss.py index 40b0cb7a..02111035 100644 --- a/codes/models/gpt_voice/dvae_arch_playground/discretization_loss.py +++ b/codes/models/gpt_voice/dvae_arch_playground/discretization_loss.py @@ -10,22 +10,47 @@ import torch.nn.functional as F # In other words, attempts to force the discretization function to have a mean equal utilization across all discrete # values with the specified expected variance. class DiscretizationLoss(nn.Module): - def __init__(self, dim, expected_variance): + def __init__(self, discrete_bins, dim, expected_variance, store_past=0): super().__init__() + self.discrete_bins = discrete_bins self.dim = dim self.dist = torch.distributions.Normal(0, scale=expected_variance) + if store_past > 0: + self.record_past = True + self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu')) + self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu')) + self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins)) + else: + self.record_past = False def forward(self, x): other_dims = set(range(len(x.shape)))-set([self.dim]) averaged = x.sum(dim=tuple(other_dims)) / x.sum() averaged = averaged - averaged.mean() + + if self.record_past: + acc_count = self.accumulator.shape[0] + avg = averaged.detach().clone() + if self.accumulator_filled > 0: + averaged = torch.mean(self.accumulator, dim=0) * (acc_count-1) / acc_count + \ + averaged / acc_count + + # Also push averaged into the accumulator. + self.accumulator[self.accumulator_index] = avg + self.accumulator_index += 1 + if self.accumulator_index >= acc_count: + self.accumulator_index *= 0 + if self.accumulator_filled <= 0: + self.accumulator_filled += 1 + return torch.sum(-self.dist.log_prob(averaged)) if __name__ == '__main__': - d = DiscretizationLoss(1, 1e-6) - v = torch.randn(16, 8192, 500) - #for k in range(5): - # v[:, random.randint(0,8192), :] += random.random()*100 - v = F.softmax(v, 1) - print(d(v)) \ No newline at end of file + d = DiscretizationLoss(1024, 1, 1e-6, store_past=20) + for _ in range(500): + v = torch.randn(16, 1024, 500) + #for k in range(5): + # v[:, random.randint(0,8192), :] += random.random()*100 + v = F.softmax(v, 1) + print(d(v)) \ No newline at end of file diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 75ce5aba..04ca00a6 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -75,6 +75,7 @@ class DiscreteVAE(nn.Module): straight_through = False, normalization = None, # ((0.5,) * 3, (0.5,) * 3), record_codes = False, + discretization_loss_averaging_steps = 100, ): super().__init__() assert num_layers >= 1, 'number of layers must be greater than or equal to 1' @@ -85,7 +86,7 @@ class DiscreteVAE(nn.Module): self.straight_through = straight_through self.codebook = Quantize(codebook_dim, num_tokens) self.positional_dims = positional_dims - self.discrete_loss = DiscretizationLoss(2, 1 / (num_tokens*2)) + self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps) assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now. if positional_dims == 2: @@ -224,7 +225,6 @@ class DiscreteVAE(nn.Module): # discretization loss disc_loss = self.discrete_loss(soft_codes) - # This is so we can debug the distribution of codes being learned. if self.record_codes and self.internal_step % 50 == 0: codes = codes.flatten() diff --git a/codes/models/vqvae/vqvae.py b/codes/models/vqvae/vqvae.py index e90a7edb..7d735f9e 100644 --- a/codes/models/vqvae/vqvae.py +++ b/codes/models/vqvae/vqvae.py @@ -106,7 +106,7 @@ class Quantize(nn.Module): quantize = input + (quantize - input).detach() if return_soft_codes: - return quantize, diff, embed_ind, soft_codes.view(input.shape) + return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,)) else: return quantize, diff, embed_ind diff --git a/codes/train.py b/codes/train.py index d7a74abe..193cc2fc 100644 --- a/codes/train.py +++ b/codes/train.py @@ -284,7 +284,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_dvae_clips.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_dvae_clips_with_discretization_loss.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()