combined dvae ftw

James Betker 2021-08-06 22:01:06 -06:00
parent 0237e96b34
commit a7496b661c
4 changed files with 11 additions and 10 deletions


@@ -13,14 +13,14 @@ class GptTts(nn.Module):
     MAX_SYMBOLS_PER_PHRASE = 200
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
-    MEL_DICTIONARY_SIZE = 1024+3
+    MEL_DICTIONARY_SIZE = 512+3
     MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
     MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
 
     def __init__(self):
         super().__init__()
         model_dim = 512
-        max_mel_frames = 900 * 3 // 8  # 900 is the max number of MEL frames. The VQVAE outputs 3/8 of the input mel as tokens.
+        max_mel_frames = 900 * 1 // 4  # 900 is the max number of MEL frames. The VQVAE outputs 1/4 of the input mel as tokens.
         self.model_dim = model_dim
         self.max_mel_frames = max_mel_frames
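
With half the mel codebook, the GPT's token budget changes too. A worked sketch of the resulting constants, assuming the standard 148-entry tacotron2 symbol set (everything else follows from the lines above):

    # Worked numbers for the constants in this hunk. NUMBER_SYMBOLS = 148 is
    # an assumed size for the tacotron2 symbol set; the rest is arithmetic.
    MAX_SYMBOLS_PER_PHRASE = 200
    NUMBER_SYMBOLS = 148                                 # len(symbols), assumed
    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2  # 350

    MEL_DICTIONARY_SIZE = 512 + 3      # 512 VQVAE codes + 3 special tokens
    MEL_START_TOKEN = MEL_DICTIONARY_SIZE - 3            # 512
    MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE - 2             # 513

    max_mel_frames = 900 * 1 // 4      # at most 225 mel tokens per sequence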


@@ -134,9 +134,10 @@ class DiscreteVAE(nn.Module):
     @torch.no_grad()
     @eval_decorator
     def get_codebook_indices(self, images):
-        logits = self(images, return_logits = True)
-        codebook_indices = logits.argmax(dim = 1).flatten(1)
-        return codebook_indices
+        img = self.norm(images)
+        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
+        sampled, commitment_loss, codes = self.codebook(logits)
+        return codes
 
     def decode(
         self,
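
The rewritten get_codebook_indices runs the input through the normalizer, encoder, and quantizer directly, instead of argmax-ing over decoder logits, so the returned codes are the same indices the quantizer produces at training time. A minimal usage sketch; the DiscreteVAE constructor arguments and the 80x900 mel shape are illustrative assumptions, not the commit's actual config:

    import torch

    dvae = DiscreteVAE()                   # hypothetical default config
    dvae.eval()

    mel = torch.randn(1, 80, 900)          # (batch, mel_channels, frames); a 3-d
                                           # input takes the (0,2,1) permute branch
    codes = dvae.get_codebook_indices(mel)
    # codes: integer codebook indices in [0, 512), temporally compressed;
    # GptTts brackets them with MEL_START_TOKEN / MEL_STOP_TOKEN.
    print(codes.shape)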


@@ -54,7 +54,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt


@@ -5,7 +5,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from data.util import is_wav_file, get_image_paths
-from models.audio_resnet import resnet34
+from models.audio_resnet import resnet34, resnet50
 from models.tacotron2.taco_utils import load_wav_to_torch
 from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict
@@ -20,13 +20,13 @@ if __name__ == '__main__':
         clip = clip[:,0]
         clip = clip[:window].unsqueeze(0)
         clip = clip / 32768.0 # Normalize
-        clip = clip + torch.rand_like(clip) * .03 # Noise (this is how the model was trained)
+        #clip = clip + torch.rand_like(clip) * .03 # Noise (this is how the model was trained)
         assert sr == 24000
         clips.append(clip)
     clips = torch.stack(clips, dim=0)
-    resnet = resnet34()
-    sd = torch.load('../experiments/train_byol_audio_clips/models/57000_generator.pth')
+    resnet = resnet50()
+    sd = torch.load('../experiments/train_byol_audio_clips/models/8000_generator.pth')
     sd = extract_byol_model_from_state_dict(sd)
     resnet.load_state_dict(sd)
     embedding = resnet(clips, return_pool=True)
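
Pulled together, the evaluation path now embeds clean (noise-free) clips with a resnet50 BYOL backbone restored from the 8000-step checkpoint. A condensed sketch using only names that appear in the diff; the batch shape is an assumption based on the clip handling above:

    import torch
    from models.audio_resnet import resnet50
    from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict

    resnet = resnet50()
    sd = torch.load('../experiments/train_byol_audio_clips/models/8000_generator.pth')
    resnet.load_state_dict(extract_byol_model_from_state_dict(sd))
    resnet.eval()

    with torch.no_grad():
        # Assumed shape: N mono clips of `window` samples at 24 kHz, in [-1, 1].
        clips = torch.randn(4, 1, 24000)
        embedding = resnet(clips, return_pool=True)   # pooled per-clip embeddings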