actually for real fixed incrementing filenames because i had a regex that actually only worked if candidates or lines>1, cuda now takes priority over dml if you're a nut with both of them installed because you can just specify an override anyways

This commit is contained in:
mrq 2023-02-16 01:06:32 +00:00
parent ec80ca632b
commit eca61af016
2 changed files with 10 additions and 20 deletions

View File

@ -23,10 +23,10 @@ def get_device_name():
name = 'cpu' name = 'cpu'
if has_dml(): if torch.cuda.is_available():
name = 'dml'
elif torch.cuda.is_available():
name = 'cuda' name = 'cuda'
elif has_dml():
name = 'dml'
return name return name
@ -67,8 +67,7 @@ def get_device_batch_size():
return 4 return 4
return 1 return 1
def get_device_count(): def get_device_count(name=get_device_name()):
name = get_device_name()
if name == "cuda": if name == "cuda":
return torch.cuda.device_count() return torch.cuda.device_count()
if name == "dml": if name == "dml":

View File

@ -155,12 +155,10 @@ def generate(
idx_cache = {} idx_cache = {}
for i, file in enumerate(os.listdir(outdir)): for i, file in enumerate(os.listdir(outdir)):
filename = os.path.basename(file) filename = os.path.basename(file)
if filename[-5:] == ".json": extension = os.path.splitext(filename)[1]
match = re.findall(rf"^{voice}_(\d+)(?:.+?)\.json$", filename) if extension != ".json" and extension != ".wav":
elif filename[-4:] == ".wav":
match = re.findall(rf"^{voice}_(\d+)(?:.+?)\.wav$", filename)
else:
continue continue
match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename)
key = int(match[0]) key = int(match[0])
idx_cache[key] = True idx_cache[key] = True
@ -169,18 +167,11 @@ def generate(
keys = sorted(list(idx_cache.keys())) keys = sorted(list(idx_cache.keys()))
idx = keys[-1] + 1 idx = keys[-1] + 1
print(f"Using index: {idx}")
# I know there's something to pad I don't care # I know there's something to pad I don't care
pad = "" pad = ""
if idx < 10000: for i in range(4,0,-1):
pad = f"{pad}0" if idx < 10 ** i:
if idx < 1000: pad = f"{pad}0"
pad = f"{pad}0"
if idx < 100:
pad = f"{pad}0"
if idx < 10:
pad = f"{pad}0"
idx = f"{pad}{idx}" idx = f"{pad}{idx}"
def get_name(line=0, candidate=0, combined=False): def get_name(line=0, candidate=0, combined=False):