Windows specific fixes (to-do: find libespeak-ng.dll automatically because it cannot be trusted to do it by default)

mrq 2024-11-03 19:19:15 -06:00
parent d229725c76
commit c83670c38c
6 changed files with 32 additions and 5 deletions
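The commit message's to-do — locating libespeak-ng.dll automatically instead of trusting the default lookup — could be handled along the lines of the sketch below. The install paths and the use of phonemizer's PHONEMIZER_ESPEAK_LIBRARY environment variable are assumptions for illustration, not code from this commit.

import os
import sys
from pathlib import Path
from typing import Optional

def find_espeak_dll() -> Optional[Path]:
	# Hypothetical helper: probe common eSpeak NG install locations on Windows.
	if not sys.platform.startswith("win"):
		return None
	candidates = [
		Path(os.environ.get("PROGRAMFILES", r"C:\Program Files")) / "eSpeak NG" / "libespeak-ng.dll",
		Path(os.environ.get("PROGRAMFILES(X86)", r"C:\Program Files (x86)")) / "eSpeak NG" / "libespeak-ng.dll",
	]
	for dll in candidates:
		if dll.exists():
			return dll
	return None

dll = find_espeak_dll()
if dll is not None and "PHONEMIZER_ESPEAK_LIBRARY" not in os.environ:
	# phonemizer reads this variable to locate the espeak shared library
	os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(dll)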

View File

@@ -37,7 +37,7 @@ setup(
	packages=find_packages(),
	install_requires=(
		# training backends
-		["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else [])
+		["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else ["psutil"])
		+ [
		# logging niceties
		"coloredlogs>=15.0.1",

View File

@@ -948,6 +948,14 @@ class Config(BaseConfig):
				_logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}")
				del model["interleave"]
+
+			if "p_rvq_levels" in model["experimental"]:
+				model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
+				del model["experimental"]["p_rvq_levels"]
+
+			if "p_len_train" in model["experimental"]:
+				model["experimental"]["len_train_p"] = model["experimental"]["p_len_train"]
+				del model["experimental"]["p_len_train"]
		self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
		self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
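The two renames above follow the same pattern as the existing interleave handling: values stored under deprecated key names are moved to their current names before the Model objects are constructed. As a standalone illustration (the helper and its call are hypothetical; only the key names come from the diff):

# Hypothetical helper mirroring the rename logic above.
_RENAMED_EXPERIMENTAL_KEYS = {
	"p_rvq_levels": "rvq_levels_p",
	"p_len_train": "len_train_p",
}

def migrate_experimental_keys( experimental: dict ) -> dict:
	# move values from deprecated keys to their current names, in place
	for old, new in _RENAMED_EXPERIMENTAL_KEYS.items():
		if old in experimental:
			experimental[new] = experimental.pop(old)
	return experimental

# e.g. migrate_experimental_keys({ "p_len_train": 0.05 }) returns { "len_train_p": 0.05 }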

View File

@@ -156,6 +156,12 @@ def main():
		comparison_kwargs["disabled"]["layer_skip"] = False
		comparison_kwargs["enabled"]["layer_skip"] = True
+	elif args.comparison == "refine-on-stop":
+		comparison_kwargs["suffix"] = "refine-on-stop"
+		comparison_kwargs["titles"] = [f"Without Ro<S>", "With Ro<S>"]
+
+		comparison_kwargs["disabled"]["refine_on_stop"] = False
+		comparison_kwargs["enabled"]["refine_on_stop"] = True
	elif args.comparison == "ar-temp":
		current_temp = args.ar_temp
		other_temp = 1.0

View File

@@ -48,8 +48,12 @@ from ..utils import wrapper as ml
_logger = logging.getLogger(__name__)

-if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
-	init_distributed(torch.distributed.init_process_group)
+# windows throws an error here
+try:
+	if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
+		init_distributed(torch.distributed.init_process_group)
+except Exception as e:
+	pass

# A very naive engine implementation using barebones PyTorch
class Engine():
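The try/except simply swallows whatever torch.distributed raises when local initialization is attempted on Windows. A slightly more explicit guard, sketched below as a hypothetical alternative (not what the commit does), would check availability before calling in:

import torch.distributed

def try_init_distributed( init_fn, *args, **kwargs ) -> bool:
	# torch.distributed.is_available() is False on builds without distributed
	# support, which covers one class of Windows failures up front
	if not torch.distributed.is_available():
		return False
	try:
		init_fn(*args, **kwargs)
		return True
	except Exception:
		# e.g. unsupported backend or missing rendezvous environment variables
		return False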

View File

@@ -5,7 +5,7 @@ import torch
import logging
import random

-from typing import Literal, overload, Optional, Tuple, Union, List, Unpack
+from typing import Literal, overload, Optional, Tuple, Union, List

from torch import Tensor, nn
from transformers.cache_utils import Cache
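Dropping Unpack from the import sidesteps the fact that typing.Unpack only exists on Python 3.11 and newer (a likely motivation here, though the commit does not say). If the name were still needed, the usual pattern is a version-guarded import, sketched under the assumption that typing_extensions is installed:

import sys

if sys.version_info >= (3, 11):
	from typing import Unpack
else:
	from typing_extensions import Unpack  # assumed fallback dependency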

View File

@@ -1,4 +1,5 @@
import os
+import sys
import re
import math
import argparse
@@ -22,6 +23,8 @@ from .emb.qnt import decode_to_wave
from .data import get_lang_symmap, get_random_prompt

+is_windows = sys.platform.startswith("win")
+
tts = None
layout = {}
@@ -68,6 +71,9 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data
			continue
		configs.append( sft )

+	if is_windows:
+		configs = [ str(p) for p in configs ]
+
	return configs

def get_dtypes():
@@ -199,7 +205,10 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
	parser.add_argument("--refine-on-stop", action="store_true")
	args, unknown = parser.parse_known_args()

-	tmp = tempfile.NamedTemporaryFile(suffix='.wav')
+	if is_windows:
+		tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
+	else:
+		tmp = tempfile.NamedTemporaryFile(suffix='.wav')

	"""
	if not args.references:
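delete=False is needed on Windows because a NamedTemporaryFile that is still open there cannot be reopened by name, unlike on Unix; the trade-off is that cleanup becomes manual. A minimal sketch of that pattern (the payload and consumer are placeholders):

import os
import tempfile

tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
try:
	tmp.write(b"...")   # placeholder payload, not real wav data
	tmp.close()         # close first so other handles can open tmp.name on Windows
	# ... hand tmp.name to whatever consumes the file ...
finally:
	os.unlink(tmp.name) # manual cleanup replaces delete=True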