diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index b8f13ec..9c35962 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -23,10 +23,10 @@ def get_device_name(): name = 'cpu' - if has_dml(): - name = 'dml' - elif torch.cuda.is_available(): + if torch.cuda.is_available(): name = 'cuda' + elif has_dml(): + name = 'dml' return name @@ -67,8 +67,7 @@ def get_device_batch_size(): return 4 return 1 -def get_device_count(): - name = get_device_name() +def get_device_count(name=get_device_name()): if name == "cuda": return torch.cuda.device_count() if name == "dml": diff --git a/webui.py b/webui.py index 73ebb67..b845531 100755 --- a/webui.py +++ b/webui.py @@ -155,12 +155,10 @@ def generate( idx_cache = {} for i, file in enumerate(os.listdir(outdir)): filename = os.path.basename(file) - if filename[-5:] == ".json": - match = re.findall(rf"^{voice}_(\d+)(?:.+?)\.json$", filename) - elif filename[-4:] == ".wav": - match = re.findall(rf"^{voice}_(\d+)(?:.+?)\.wav$", filename) - else: + extension = os.path.splitext(filename)[1] + if extension != ".json" and extension != ".wav": continue + match = re.findall(rf"^{voice}_(\d+)(?:.+?)?{extension}$", filename) key = int(match[0]) idx_cache[key] = True @@ -169,18 +167,11 @@ def generate( keys = sorted(list(idx_cache.keys())) idx = keys[-1] + 1 - print(f"Using index: {idx}") - # I know there's something to pad I don't care pad = "" - if idx < 10000: - pad = f"{pad}0" - if idx < 1000: - pad = f"{pad}0" - if idx < 100: - pad = f"{pad}0" - if idx < 10: - pad = f"{pad}0" + for i in range(4,0,-1): + if idx < 10 ** i: + pad = f"{pad}0" idx = f"{pad}{idx}" def get_name(line=0, candidate=0, combined=False):