combined dvae ftw

commit a7496b661c
parent 0237e96b34
@@ -13,14 +13,14 @@ class GptTts(nn.Module):
     MAX_SYMBOLS_PER_PHRASE = 200
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
-    MEL_DICTIONARY_SIZE = 1024+3
+    MEL_DICTIONARY_SIZE = 512+3
     MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
     MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2

     def __init__(self):
         super().__init__()
         model_dim = 512
-        max_mel_frames = 900 * 3 // 8  # 900 is the max number of MEL frames. The VQVAE outputs 3/8 of the input mel as tokens.
+        max_mel_frames = 900 * 1 // 4  # 900 is the max number of MEL frames. The VQVAE outputs 1/4 of the input mel as tokens.

         self.model_dim = model_dim
         self.max_mel_frames = max_mel_frames
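For context, the three slots above the 512 codebook entries serve as control tokens at the top of the mel dictionary. A minimal sketch of how a code sequence would be bracketed with them (the pad-token assignment and the wrap_mel_codes helper are illustrative assumptions, not code from this commit):

import torch

MEL_DICTIONARY_SIZE = 512 + 3
MEL_START_TOKEN = MEL_DICTIONARY_SIZE - 3   # 512
MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE - 2    # 513
MEL_PAD_TOKEN = MEL_DICTIONARY_SIZE - 1     # 514; assumed use of the last reserved slot

def wrap_mel_codes(codes):
    # Bracket a 1D tensor of VQVAE codebook indices (each < 512) with start/stop.
    start = torch.tensor([MEL_START_TOKEN], dtype=codes.dtype)
    stop = torch.tensor([MEL_STOP_TOKEN], dtype=codes.dtype)
    return torch.cat([start, codes, stop])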
@@ -134,9 +134,10 @@ class DiscreteVAE(nn.Module):
     @torch.no_grad()
     @eval_decorator
     def get_codebook_indices(self, images):
-        logits = self(images, return_logits = True)
-        codebook_indices = logits.argmax(dim = 1).flatten(1)
-        return codebook_indices
+        img = self.norm(images)
+        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
+        sampled, commitment_loss, codes = self.codebook(logits)
+        return codes

     def decode(
         self,
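The rewritten method skips the full VAE forward pass and asks the codebook directly for discrete indices; self.codebook evidently returns a (sampled, commitment_loss, codes) triple. A hedged sketch of a vector-quantization codebook with that interface, assuming nearest-neighbor lookup (the class name and shapes are illustrative, not the repo's implementation):

import torch
import torch.nn.functional as F

class NearestCodebook(torch.nn.Module):
    def __init__(self, num_codes=512, dim=512):
        super().__init__()
        self.embed = torch.nn.Embedding(num_codes, dim)

    def forward(self, z):
        # z: (batch, positions, dim) -- the permuted encoder output.
        flat = z.reshape(-1, z.shape[-1])               # (batch*positions, dim)
        d = torch.cdist(flat, self.embed.weight)        # distance to every code vector
        codes = d.argmin(dim=-1).reshape(z.shape[:-1])  # discrete indices, (batch, positions)
        sampled = self.embed(codes)                     # quantized vectors
        commitment_loss = F.mse_loss(sampled.detach(), z)
        return sampled, commitment_loss, codes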
@@ -54,7 +54,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
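Only the no-flag default changes here; the old GPT-TTS test remains reachable by passing -opt explicitly. A quick sketch of the override behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.',
                    default='../options/test_vqvae_audio_lj.yml')
print(parser.parse_args([]).opt)   # no flag -> new VQVAE audio config
print(parser.parse_args(['-opt', '../options/test_gpt_tts_lj.yml']).opt)  # old config still works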
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from data.util import is_wav_file, get_image_paths
-from models.audio_resnet import resnet34
+from models.audio_resnet import resnet34, resnet50
 from models.tacotron2.taco_utils import load_wav_to_torch
 from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict

@@ -20,13 +20,13 @@ if __name__ == '__main__':
         clip = clip[:,0]
         clip = clip[:window].unsqueeze(0)
         clip = clip / 32768.0  # Normalize
-        clip = clip + torch.rand_like(clip) * .03  # Noise (this is how the model was trained)
+        #clip = clip + torch.rand_like(clip) * .03  # Noise (this is how the model was trained)
         assert sr == 24000
         clips.append(clip)
     clips = torch.stack(clips, dim=0)

-    resnet = resnet34()
-    sd = torch.load('../experiments/train_byol_audio_clips/models/57000_generator.pth')
+    resnet = resnet50()
+    sd = torch.load('../experiments/train_byol_audio_clips/models/8000_generator.pth')
     sd = extract_byol_model_from_state_dict(sd)
     resnet.load_state_dict(sd)
     embedding = resnet(clips, return_pool=True)
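Condensed, the embedding path after this change looks like the sketch below. The resnet50 call and checkpoint path come straight from the hunk; load_clip, wav_paths, and the window value are illustrative glue, and load_wav_to_torch is assumed to return an (audio, sample_rate) pair:

import torch
from models.audio_resnet import resnet50
from models.tacotron2.taco_utils import load_wav_to_torch
from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict

def load_clip(path, window):
    audio, sr = load_wav_to_torch(path)   # assumed (audio, sample_rate) return
    assert sr == 24000                    # the script expects 24 kHz wavs
    clip = audio[:window].unsqueeze(0)    # first `window` samples, add channel dim
    return clip / 32768.0                 # int16 range -> [-1, 1]; training noise now disabled

resnet = resnet50()
sd = torch.load('../experiments/train_byol_audio_clips/models/8000_generator.pth')
resnet.load_state_dict(extract_byol_model_from_state_dict(sd))
resnet.eval()
wav_paths = ['clip1.wav', 'clip2.wav']    # hypothetical inputs
clips = torch.stack([load_clip(p, window=48000) for p in wav_paths], dim=0)
with torch.no_grad():
    embedding = resnet(clips, return_pool=True)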