# stable-diffusion-utils/src/preprocess.py
# Credits to https://gist.github.com/nnuudev/56ed3242023c8582a32e3130ef59730b / https://boards.4chan.org/trash/thread/51463059#p51472156
import os
import re
import json
import time
import shutil
import math
import urllib.request
# Default configuration; every key can be overridden by the JSON file at 'source'.
config = {
    'source': "./data/config/preprocess.json",  # optional JSON file of overrides for this dict
    'input': './images/downloaded/', # files to process
    'output': './images/tagged/', # folder to copy tagged files into
    'tags': './data/tags.csv', # csv of tags associated with the yiffy model (replace for other flavor of booru's taglist associated with the model you're training against)
    'cache': './data/cache.json', # JSON file of cached tags, will speed up processing if re-running
    'rateLimit': 500, # time to wait between requests, in milliseconds, e621 imposes a rate limit of 2 requests per second
    'filenameLimit': 240, # maximum characters to put in the filename, necessary to abide by filesystem limitations
    # you can set this to 245, as the web UI has uncapped the prompt limit, but I have yet to test whether this limit was also lifted for textual inversion
    'filter': True,  # master switch for the 'filters'/'filtersRegex' lists below
    # fill it with tags of whatever you don't want to make it into the filename
    # for starters, you can also add "anthro", "male", "female", as they're very common tags
    'filters': [
        # "anthro",
        # "fur",
        # "male",
        # "female",
        "animal genitalia",
        "genitals",
        "video games",
    ],
    # regex patterns applied to each tag; a match drops the tag
    'filtersRegex': [
        r"clothing$",
        r"fluids$",
        r" (fe)?male$",
    ],
    # treat these tag CATEGORIES (e.g. "character", "species") as already being included in the taglist
    # if you're cautious (paranoid), include species you want, but I found I don't really even need to include species
    # you can also include character names / series names if you're using this for hypernetworks
    'tagsOverride': [],
    # 'tagsOverride': ["character", "species", "copyright"], # useful for hypernetwork training
    'tagsOverrideStart': 1000000, # starting score that your overridden tags will start from, for sorting purposes
    # tags to always include in the list
    # I HIGHLY suggest including these tags in your training template instead
    'tagsAutoInclude': [],
    'removeParentheses': True, # removes disambiguators like `blaidd_(elden_ring)` or `curt_(animal_crossing)` without needing to specify it all in the above
    # good because it messes with a lot of things
    'onlyIncludeModelArtists': True, # if True, only include the artist's tag if in the model's taglist, if false, add all artists
    # i've noticed some artists that weren't included in the taglist, but are available in LAION's (vanilla SD)
    'reverseTags': False, # inverts sorting, prioritizing tags with little representation in the model
    'tagDelimiter': ",", # what separates each tag in the filename, web UI will accept comma separated filenames
    'invalidCharacters': "\\/:*?\"<>|", # characters that can't go in a filename
    'lora': True, # set to true to enable outputting for LoRA training
}
# --- one-time setup: merge config overrides, load the taglist, prime the cache ---

# Merge user overrides from the JSON config file (if present) on top of the defaults.
# Best-effort by design: a missing or malformed file just keeps the defaults, but we
# only swallow I/O and parse errors instead of the original bare `except:`.
if os.path.exists(config['source']):
    try:
        with open(config['source'], 'r', encoding='utf-8') as f:
            for key, value in json.load(f).items():
                config[key] = value
        print(f"Imported settings from {config['source']}")
    except (OSError, ValueError):  # ValueError covers json.JSONDecodeError
        pass

# Load the model's taglist CSV ("tag,score" per line) into a tag -> score dict.
with open(config['tags'], 'rb') as f:
    csv = f.read().decode('utf-8').split("\n")
config['tags'] = {}
for line in csv:
    line = line.strip()
    if not line:
        continue  # fix: a trailing newline yields an empty line that crashed the unpack
    # rsplit so a (hypothetical) comma inside the tag itself can't break the unpack;
    # the score is always the last field
    k, v = line.rsplit(',', 1)
    config['tags'][k] = int(v)

# Seed override tags with descending synthetic scores so they sort ahead of real tags.
for entry in config['tagsOverride']:
    config['tags'][entry.replace("_", " ")] = config['tagsOverrideStart']
    config['tagsOverrideStart'] -= 1

# Prime the MD5 -> post-metadata cache; missing/corrupt cache just means re-fetching.
cache = {}
try:
    with open(config['cache'], 'rb') as f:
        cache = json.loads(f.read().decode('utf-8'))
except (OSError, ValueError):
    pass

# LoRA training reads captions from sidecar .txt files, so the filename cap is moot.
if config['lora']:
    config['filenameLimit'] = 0

# Pad a single-character delimiter with a space for prompt readability.
if len(config['tagDelimiter']) == 1:
    config['tagDelimiter'] = config['tagDelimiter'] + " "
def parse():
    """Tag every image in config['input'] and copy it, captioned, to config['output'].

    For each file whose name contains a 32-hex-digit e621 MD5, look up the post via
    the e621 API (memoized in `cache`), build a caption of model-known tags sorted by
    taglist score, then either rename the copy to the caption (textual inversion) or
    write the caption to a sidecar ``<index>.txt`` (LoRA mode). Only fresh API
    lookups are rate-limited. The cache is persisted at the end.
    """
    global config, cache
    files = list(os.listdir(config['input']))
    for index, file in enumerate(files):
        # Accept two filename layouts:
        #   "83737b5e961b594c26e8feaed301e7a5 (1).jpg"   (duplicate copies from a file manager)
        #   "00001-83737b5e961b594c26e8feaed301e7a5.jpg" (output from voldy's web UI preprocessing)
        match = re.match(r"^([a-f0-9]{32})", file)
        if not match:
            # fix: this pattern is anchored at the END, so re.match (which anchors at
            # the start) could never match the "00001-<md5>.jpg" layout; use re.search
            match = re.search(r"([a-f0-9]{32})\.(jpe?g|png)$", file)
        if not match:
            continue
        md5 = match.group(1)
        print(f"[{(100.0 * index / len(files)):3.0f}%]: {md5}")
        fetched = False  # did this iteration actually hit the network?
        if md5 not in cache:
            fetched = True
            req = urllib.request.Request(
                f"https://e621.net/posts.json?tags=md5:{md5}",
                headers={
                    # e621 rejects requests without a browser-like user agent
                    'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36'
                },
            )
            with urllib.request.urlopen(req) as r:
                j = json.loads(r.read())
            # cache an empty dict as a negative result so we don't re-query this MD5
            cache[md5] = j["posts"][0] if j["posts"] else {}
        json_meta = cache[md5]
        if not json_meta:
            # fix: the original `continue` skipped the rate-limit sleep even when a
            # request was just made for an unknown MD5
            if fetched and config['rateLimit']:
                time.sleep(config['rateLimit'] / 1000.0)
            continue
        tags = config['tagsAutoInclude'].copy()
        artist = ""
        content = {
            "s": "safe content",
            "q": "questionable content",
            "e": "explicit content",  # fix: was misspelled "explict content"
        }.get(json_meta["rating"], "")
        for cat in json_meta["tags"]:
            cat_override = cat in config['tagsOverride']
            if cat == "artist":
                tag = "by " + " and ".join(json_meta["tags"]["artist"])
                # NOTE(review): artist names keep underscores and carry the "by "
                # prefix when checked against the taglist — presumably tags.csv
                # stores artists in that form; confirm against the taglist
                if config['onlyIncludeModelArtists'] and tag not in config['tags']:
                    continue
                artist = tag
            else:
                for raw in json_meta["tags"][cat]:
                    tag = raw.replace("_", " ")
                    override = cat_override or tag in config['tagsOverride']
                    if override:
                        if tag not in config['tags']:
                            # Synthesize a score so overridden tags sort first; earlier
                            # categories in tagsOverride get a larger decade scale.
                            if cat in config['tagsOverride']:
                                idx = config['tagsOverride'].index(cat)
                                scale = math.pow(10, len(config['tagsOverride']) - idx + 1)
                            else:
                                # fix: the tag itself (not its category) matched the
                                # override list; the original called list.index(cat)
                                # expecting a -1 sentinel and crashed with ValueError
                                scale = 1
                            config['tags'][tag] = config['tagsOverrideStart'] * scale
                            config['tagsOverrideStart'] -= 1
                    elif tag not in config['tags']:
                        continue  # not in the model's taglist; drop it
                    # drop tags containing characters illegal in a filename
                    if any(ch in tag for ch in config['invalidCharacters']):
                        continue
                    if config['filter']:
                        if tag in config['filters']:
                            continue  # was break in the original script, fixed ;)
                        if any(re.search(pat, tag) for pat in config['filtersRegex']):
                            continue
                    tags.append(tag)
        # fix: .get(…, 0) — tagsAutoInclude entries absent from the taglist used to
        # raise KeyError here
        tags.sort(key=lambda t: -config['tags'].get(t, 0), reverse=config['reverseTags'])
        if artist:
            tags.insert(0, artist)
        if content:
            tags.insert(0, content)
        # Build the caption, truncating at filenameLimit (0 = unlimited, LoRA mode).
        running = ""
        kept = []
        for tag in tags:
            if config['filenameLimit'] > 0 and len(running + config['tagDelimiter'] + tag) > config['filenameLimit']:
                break
            running += config['tagDelimiter'] + tag
            if config['removeParentheses']:
                # strip trailing disambiguators like "(elden ring)"
                tag = re.sub(r"\(.+?\)$", "", tag).strip()
            kept.append(tag)
        joined = config['tagDelimiter'].join(kept)
        if config['lora']:
            # LoRA: image keeps an index-based name; caption goes to a sidecar .txt
            shutil.copy(os.path.join(config['input'], file),
                        os.path.join(config['output'], file.replace(md5, f'{index}').strip()))
            with open(os.path.join(config['output'], f"{index}.txt"), 'wb') as f:
                f.write(joined.encode('utf-8'))
        else:
            # textual inversion: the caption IS the filename
            shutil.copy(os.path.join(config['input'], file),
                        os.path.join(config['output'], file.replace(md5, joined).strip()))
        if fetched and config['rateLimit']:
            time.sleep(config['rateLimit'] / 1000.0)
    # NOOOOOOOO YOU'RE WASTING SPACE BY PRETTIFYING
    with open(config['cache'], 'wb') as f:
        f.write(json.dumps(cache, indent='\t').encode('utf-8'))
# Entry point: run the tagger only when executed as a script, not on import.
if __name__ == "__main__":
    parse()