import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

import random
import tqdm
from modules import devices

import re

# Words pulled from an image's filename for the "[filewords]" template placeholder:
# a letter followed by one or more word characters or parentheses.
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")


class PersonalizedBase(Dataset):
    """Dataset of pre-encoded images paired with prompts built from a template file.

    Every file in ``data_root`` with a .jpg/.jpeg/.png extension (case-insensitive)
    is loaded once, converted to RGB, resized, scaled to [-1, 1], encoded with
    ``model``'s first stage, and cached on ``devices.cpu``. ``__getitem__`` returns
    that cached latent together with a randomly chosen template line in which
    ``[name]`` is replaced by ``placeholder_token`` and ``[filewords]`` by words
    taken from the image's filename. ``len(self)`` is number of images * ``repeats``.
    """

    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
        """
        Args:
            data_root: directory listed (non-recursively, via os.listdir) for images.
            width, height: size every image is resized to before encoding.
            repeats: how many times each image occurs in one pass over ``self.indexes``.
            flip_p: probability passed to ``self.flip``. NOTE(review): the flip
                transform is constructed but never applied by this class.
            placeholder_token: text substituted for ``[name]`` in templates.
            model: object providing ``encode_first_stage`` and
                ``get_first_stage_encoding``; called for every image despite the
                ``None`` default.
            device: device the pixel tensor is moved to before encoding.
            template_file: path of a text file with one prompt template per line.

        Raises:
            AssertionError: if ``data_root`` is empty/falsy.
        """
        self.placeholder_token = placeholder_token

        self.width = width
        self.height = height
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.extns = [".jpg", ".jpeg", ".png"]

        # List of (latent, filename_tokens) tuples, one per image.
        self.dataset = []

        with open(template_file, "r") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        assert data_root, 'dataset directory not specified'

        self.image_paths = [
            os.path.join(data_root, file_path)
            for file_path in os.listdir(data_root)
            if os.path.splitext(file_path.casefold())[1] in self.extns
        ]
        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            # The context manager guarantees the file handle is released even for
            # images PIL would otherwise keep open after decoding.
            with Image.open(path) as source:
                image = source.convert('RGB')
            image = image.resize((self.width, self.height), PIL.Image.BICUBIC)

            filename = os.path.basename(path)
            filename_tokens = os.path.splitext(filename)[0]
            filename_tokens = re_tag.findall(filename_tokens)

            npimage = np.array(image).astype(np.uint8)
            # Map uint8 pixel values [0, 255] to float32 in [-1, 1].
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
            # (H, W, C) -> (C, H, W) channel-first layout for the encoder.
            torchdata = torch.moveaxis(torchdata, 2, 0)

            # Encode a batch of one; squeeze() drops every size-1 dim, including
            # the batch dim.
            init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
            # Cache on the CPU (presumably to keep `device` memory free).
            init_latent = init_latent.to(devices.cpu)

            self.dataset.append((init_latent, filename_tokens))

        self.length = len(self.dataset) * repeats

        # Every image index appears `repeats` times; shuffle() permutes this array.
        self.initial_indexes = np.arange(self.length) % len(self.dataset)
        self.indexes = None
        self.shuffle()

    def shuffle(self):
        """Replace ``self.indexes`` with a fresh random permutation of ``self.initial_indexes``."""
        # Convert the torch permutation to NumPy before indexing: NumPy treats a
        # one-element tensor as a scalar index (via __index__), which would make
        # self.indexes a scalar for a single-sample dataset.
        order = torch.randperm(self.initial_indexes.shape[0]).numpy()
        self.indexes = self.initial_indexes[order]

    def __len__(self):
        """Return number of images * repeats."""
        return self.length

    def __getitem__(self, i):
        """Return ``(latent, prompt)`` for sample position ``i``.

        ``i`` may exceed ``len(self)``; positions wrap around ``self.indexes``.
        A new permutation is drawn each time a pass over ``self.indexes`` starts,
        so with sequential access every image appears exactly ``repeats`` times
        per pass.
        """
        if i % len(self.indexes) == 0:
            self.shuffle()

        index = self.indexes[i % len(self.indexes)]
        x, filename_tokens = self.dataset[index]

        text = random.choice(self.lines)
        text = text.replace("[name]", self.placeholder_token)
        text = text.replace("[filewords]", ' '.join(filename_tokens))

        return x, text