"""Compare BYOL-trained audio ResNet-34 embeddings across a folder of wav clips.

Every wav file found under ``root_path`` is cut (or zero-padded) to ``window``
samples, divided by 32768, and embedded with a ResNet-34. The ResNet's weights
are extracted from a BYOL training checkpoint. Each clip is then used in turn
as a baseline, and the mean-squared error between its embedding and every
other clip's embedding is printed.
"""
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from data.util import is_wav_file, get_image_paths
from models.audio_resnet import resnet34
from models.tacotron2.taco_utils import load_wav_to_torch
from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict


def _load_clip(path, window, expected_sr):
    """Load one wav file and return it as a ``(1, window)`` float tensor.

    Args:
        path: Path to the wav file.
        window: Number of samples to keep; shorter clips are zero-padded.
        expected_sr: Required sample rate of the file.

    Raises:
        ValueError: If the file's sample rate differs from ``expected_sr``.
    """
    clip, sr = load_wav_to_torch(path)
    if sr != expected_sr:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError(f'{path}: expected sample rate {expected_sr}, got {sr}')
    if clip.dim() > 1:
        # Keep only the first channel; assumes a (samples, channels) layout
        # for multi-channel audio — TODO confirm against load_wav_to_torch.
        clip = clip[:, 0]
    clip = clip[:window]
    if clip.shape[0] < window:
        # Zero-pad short clips so all clips share one length and can be stacked.
        clip = F.pad(clip, (0, window - clip.shape[0]))
    # Scale by 32768, presumably mapping 16-bit PCM values into [-1, 1) —
    # NOTE(review): confirm load_wav_to_torch returns raw int16-range floats.
    return (clip / 32768.0).unsqueeze(0)


def main():
    """Embed every clip in the clip directory and print pairwise embedding MSE."""
    window = 48000
    expected_sr = 24000
    root_path = 'D:\\tmp\\clips'
    checkpoint_path = '../experiments/train_byol_audio_clips/models/66000_generator.pth'

    # First element of get_image_paths' return value, presumably the file list.
    paths = get_image_paths('img', root_path, qualifier=is_wav_file)[0]
    if not paths:
        print(f'No wav files found under {root_path}')
        return

    # os.path.join returns the second argument unchanged when it is already
    # absolute, so this works whether `paths` are relative or absolute.
    clips = torch.stack(
        [_load_clip(os.path.join(root_path, path), window, expected_sr) for path in paths],
        dim=0,
    )  # (num_clips, 1, window)

    resnet = resnet34()
    # The model and clips live on CPU, so load the checkpoint onto CPU too;
    # this also lets GPU-saved checkpoints load on CPU-only machines.
    sd = torch.load(checkpoint_path, map_location='cpu')
    sd = extract_byol_model_from_state_dict(sd)
    resnet.load_state_dict(sd)
    # Inference mode: normalization layers use their stored statistics instead
    # of this batch's, so each clip's embedding does not depend on the other
    # clips in the batch.
    resnet.eval()
    with torch.no_grad():
        embedding = resnet(clips, return_pool=True)

    for i, path in enumerate(paths):
        print(f'Using a baseline of {path}..')
        for j, cpath in enumerate(paths):
            if i == j:
                continue
            mse = F.mse_loss(embedding[j], embedding[i])
            print(f'Compared to {cpath}: {mse.item()}')


if __name__ == '__main__':
    main()