import os
import random

import numpy as np
import PIL
import torch
import tqdm
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class PersonalizedBase(Dataset):
    """Dataset of pre-encoded image latents paired with prompt templates,
    used for training a placeholder token (textual inversion style)."""

    def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*",
                 width=512, height=512, model=None, device=None, template_file=None):
        self.placeholder_token = placeholder_token

        self.size = size
        self.width = width
        self.height = height
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        # Prompt templates, one per line; may contain the [name] and [filewords] placeholders.
        with open(template_file, "r") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        assert data_root, 'dataset directory not specified'

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]

        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            image = Image.open(path)
            image = image.convert('RGB')
            image = image.resize((self.width, self.height), PIL.Image.BICUBIC)

            # Derive caption words from the filename, splitting on '_', ' ' and '-'.
            filename = os.path.basename(path)
            filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
            filename_tokens = [token for token in filename_tokens if token.isalpha()]

            # Scale pixel values from [0, 255] to [-1, 1] and convert HWC -> CHW.
            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
            torchdata = torch.moveaxis(torchdata, 2, 0)

            # Encode the image once with the model's first-stage (VAE) encoder and cache the latent.
            init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()

            self.dataset.append((init_latent, filename_tokens))

        # Each image is repeated `repeats` times per epoch.
        self.length = len(self.dataset) * repeats

        self.initial_indexes = np.arange(self.length) % len(self.dataset)
        self.indexes = None
        self.shuffle()

    def shuffle(self):
        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        # Reshuffle whenever a full pass over the underlying images completes.
        if i % len(self.dataset) == 0:
            self.shuffle()

        index = self.indexes[i % len(self.indexes)]
        x, filename_tokens = self.dataset[index]

        # Fill the prompt template with the placeholder token and the filename-derived words.
        text = random.choice(self.lines)
        text = text.replace("[name]", self.placeholder_token)
        text = text.replace("[filewords]", ' '.join(filename_tokens))

        return x, text
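

# A minimal usage sketch, not part of the original module: it assumes `sd_model` is an
# already-loaded Stable Diffusion LatentDiffusion instance (providing encode_first_stage and
# get_first_stage_encoding), and the paths below are hypothetical placeholders.
#
#     from torch.utils.data import DataLoader
#
#     dataset = PersonalizedBase(
#         data_root="train/images",             # folder of training images (hypothetical path)
#         repeats=100,
#         placeholder_token="*",
#         width=512, height=512,
#         model=sd_model,                       # loaded LatentDiffusion model (assumption)
#         device="cuda",
#         template_file="train/templates.txt",  # one prompt template per line (hypothetical path)
#     )
#     loader = DataLoader(dataset, batch_size=1, shuffle=False)
#     for latent, prompt in loader:
#         ...                                   # feed the cached latent and prompt into the training step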