diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py
index eb0b56ad..c76ee65e 100644
--- a/codes/models/gpt_voice/gpt_tts.py
+++ b/codes/models/gpt_voice/gpt_tts.py
@@ -13,14 +13,14 @@ class GptTts(nn.Module):
     MAX_SYMBOLS_PER_PHRASE = 200
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
-    MEL_DICTIONARY_SIZE = 1024+3
+    MEL_DICTIONARY_SIZE = 512+3
     MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
     MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
 
     def __init__(self):
         super().__init__()
         model_dim = 512
-        max_mel_frames = 900 * 3 // 8 # 900 is the max number of MEL frames. The VQVAE outputs 3/8 of the input mel as tokens.
+        max_mel_frames = 900 * 1 // 4 # 900 is the max number of MEL frames. The VQVAE outputs 1/4 of the input mel as tokens.
         self.model_dim = model_dim
         self.max_mel_frames = max_mel_frames
diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py
index cd5c25cc..d297fa00 100644
--- a/codes/models/gpt_voice/lucidrains_dvae.py
+++ b/codes/models/gpt_voice/lucidrains_dvae.py
@@ -134,9 +134,10 @@ class DiscreteVAE(nn.Module):
     @torch.no_grad()
     @eval_decorator
     def get_codebook_indices(self, images):
-        logits = self(images, return_logits = True)
-        codebook_indices = logits.argmax(dim = 1).flatten(1)
-        return codebook_indices
+        img = self.norm(images)
+        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
+        sampled, commitment_loss, codes = self.codebook(logits)
+        return codes
 
     def decode(
         self,
diff --git a/codes/scripts/audio/test_audio_gen.py b/codes/scripts/audio/test_audio_gen.py
index 32ab641f..783564cc 100644
--- a/codes/scripts/audio/test_audio_gen.py
+++ b/codes/scripts/audio/test_audio_gen.py
@@ -54,7 +54,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
diff --git a/codes/scripts/audio/test_audio_similarity.py b/codes/scripts/audio/test_audio_similarity.py
index 2ffaeb3d..a7afc5bf 100644
--- a/codes/scripts/audio/test_audio_similarity.py
+++ b/codes/scripts/audio/test_audio_similarity.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from data.util import is_wav_file, get_image_paths
-from models.audio_resnet import resnet34
+from models.audio_resnet import resnet34, resnet50
 from models.tacotron2.taco_utils import load_wav_to_torch
 from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict
@@ -20,13 +20,13 @@ if __name__ == '__main__':
         clip = clip[:,0]
         clip = clip[:window].unsqueeze(0)
         clip = clip / 32768.0 # Normalize
-        clip = clip + torch.rand_like(clip) * .03 # Noise (this is how the model was trained)
+        #clip = clip + torch.rand_like(clip) * .03 # Noise (this is how the model was trained)
         assert sr == 24000
         clips.append(clip)
     clips = torch.stack(clips, dim=0)
-    resnet = resnet34()
-    sd = torch.load('../experiments/train_byol_audio_clips/models/57000_generator.pth')
+    resnet = resnet50()
+    sd = torch.load('../experiments/train_byol_audio_clips/models/8000_generator.pth')
     sd = extract_byol_model_from_state_dict(sd)
     resnet.load_state_dict(sd)
     embedding = resnet(clips, return_pool=True)
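
Note on the lucidrains_dvae.py change: get_codebook_indices() no longer takes the argmax over decoder logits; it normalizes the input, runs only the encoder, and quantizes the encoder output directly against the VQ codebook, returning the discrete code indices. The sketch below is a minimal, self-contained illustration of that flow (normalize -> encode -> codebook lookup -> codes) and of the 1/4 token rate now assumed by gpt_tts.py. ToyQuantizer, ToyMelTokenizer, and every layer choice here are hypothetical stand-ins, not the actual DiscreteVAE internals.

# Illustrative sketch only: module names and layer shapes are assumptions,
# not the repo's DiscreteVAE implementation.
import torch
import torch.nn as nn


class ToyQuantizer(nn.Module):
    """Nearest-neighbour codebook lookup, standing in for the VQ codebook."""
    def __init__(self, num_codes=515, dim=64):
        super().__init__()
        self.embed = nn.Embedding(num_codes, dim)

    def forward(self, z):                                         # z: (B, T, dim)
        # Squared L2 distance from each encoder vector to every codebook entry.
        d = (z.unsqueeze(-2) - self.embed.weight).pow(2).sum(-1)  # (B, T, num_codes)
        codes = d.argmin(dim=-1)                                  # (B, T) discrete indices
        quantized = self.embed(codes)                             # (B, T, dim)
        commitment_loss = (quantized.detach() - z).pow(2).mean()
        return quantized, commitment_loss, codes


class ToyMelTokenizer(nn.Module):
    def __init__(self, mel_channels=80, dim=64):
        super().__init__()
        self.norm = nn.InstanceNorm1d(mel_channels)
        # Stride-4 conv: roughly one token per 4 input mel frames (the new 1/4 rate).
        self.encoder = nn.Conv1d(mel_channels, dim, kernel_size=4, stride=4)
        self.codebook = ToyQuantizer(dim=dim)

    @torch.no_grad()
    def get_codebook_indices(self, mels):                         # mels: (B, mel_channels, T)
        x = self.norm(mels)
        z = self.encoder(x).permute(0, 2, 1)                      # channels-last for the codebook
        _, _, codes = self.codebook(z)
        return codes


if __name__ == '__main__':
    tok = ToyMelTokenizer()
    codes = tok.get_codebook_indices(torch.randn(2, 80, 900))
    print(codes.shape)  # torch.Size([2, 225]), i.e. 900 * 1 // 4 tokens per clip

For a 900-frame mel this yields 225 tokens per clip, which lines up with max_mel_frames = 900 * 1 // 4 in gpt_tts.py, and num_codes=515 mirrors MEL_DICTIONARY_SIZE = 512+3: 512 VQ codes plus three reserved IDs, two of which are MEL_START_TOKEN and MEL_STOP_TOKEN.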