Windows specific fixes (to-do: find libespeak-ng.dll automatically because it cannot be trusted to do it by default)

This commit is contained in:
mrq 2024-11-03 19:19:15 -06:00
parent d229725c76
commit c83670c38c
6 changed files with 32 additions and 5 deletions

View File

@ -37,7 +37,7 @@ setup(
packages=find_packages(),
install_requires=(
# training backends
["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else [])
["deepspeed>=0.7.7"] if not sys.platform.startswith("win") else ["psutil"])
+ [
# logging niceties
"coloredlogs>=15.0.1",

View File

@ -948,6 +948,14 @@ class Config(BaseConfig):
_logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}")
del model["interleave"]
if "p_rvq_levels" in model["experimental"]:
model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
del model["experimental"]["p_rvq_levels"]
if "p_len_train" in model["experimental"]:
model["experimental"]["len_train_p"] = model["experimental"]["p_len_train"]
del model["experimental"]["p_len_train"]
self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]

View File

@ -156,6 +156,12 @@ def main():
comparison_kwargs["disabled"]["layer_skip"] = False
comparison_kwargs["enabled"]["layer_skip"] = True
elif args.comparison == "refine-on-stop":
comparison_kwargs["suffix"] = "refine-on-stop"
comparison_kwargs["titles"] = [f"Without Ro<S>", "With Ro<S>"]
comparison_kwargs["disabled"]["refine_on_stop"] = False
comparison_kwargs["enabled"]["refine_on_stop"] = True
elif args.comparison == "ar-temp":
current_temp = args.ar_temp
other_temp = 1.0

View File

@ -48,8 +48,12 @@ from ..utils import wrapper as ml
_logger = logging.getLogger(__name__)
# windows throws an error here
try:
if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
init_distributed(torch.distributed.init_process_group)
except Exception as e:
pass
# A very naive engine implementation using barebones PyTorch
class Engine():

View File

@ -5,7 +5,7 @@ import torch
import logging
import random
from typing import Literal, overload, Optional, Tuple, Union, List, Unpack
from typing import Literal, overload, Optional, Tuple, Union, List
from torch import Tensor, nn
from transformers.cache_utils import Cache

View File

@ -1,4 +1,5 @@
import os
import sys
import re
import math
import argparse
@ -22,6 +23,8 @@ from .emb.qnt import decode_to_wave
from .data import get_lang_symmap, get_random_prompt
is_windows = sys.platform.startswith("win")
tts = None
layout = {}
@ -68,6 +71,9 @@ def get_model_paths( paths=[Path("./training/"), Path("./models/"), Path("./data
continue
configs.append( sft )
if is_windows:
configs = [ str(p) for p in configs ]
return configs
def get_dtypes():
@ -199,6 +205,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--refine-on-stop", action="store_true")
args, unknown = parser.parse_known_args()
if is_windows:
tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
else:
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
"""