commit 4049f51ba9 (parent 023c3af331)

    added option to load lora directly from the model file itself with --lora
@@ -17,6 +17,8 @@ def main():
     parser.add_argument("--out-path", type=Path, default=None)
 
     parser.add_argument("--yaml", type=Path, default=None)
+    parser.add_argument("--model", type=Path, default=None)
+    parser.add_argument("--lora", type=Path, default=None)
 
     parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
     parser.add_argument("--max-nar-levels", type=int, default=7)
@@ -53,7 +55,14 @@ def main():
     parser.add_argument("--attention", type=str, default=None)
     args = parser.parse_args()
 
-    tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
+    config = None
+
+    if args.yaml:
+        config = args.yaml
+    elif args.model:
+        config = args.model
+
+    tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
     output = tts.inference(
         text=args.text,
         references=args.references,
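
With the two hunks above, the CLI resolves its config from either --yaml or --model and threads the optional --lora through to TTS. A minimal programmatic sketch of the same call (the import path and the .sft filenames are illustrative assumptions, not taken from the diff):

    from pathlib import Path
    from vall_e.inference import TTS  # assumed module path for the TTS class patched further below

    # point config= at a bare model checkpoint, and lora= at a LoRA trained against it
    tts = TTS( config=Path("./model.sft"), lora=Path("./lora.sft") )
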
@@ -27,15 +27,11 @@ from .utils import set_seed, prune_missing
 @dataclass()
 class BaseConfig:
     yaml_path: str | None = None # path passed in through --yaml
-    model_path: str | None = None # path passed in through --model
 
     @property
     def cfg_path(self):
         if self.yaml_path:
             return Path(self.yaml_path.parent)
 
-        if self.model_path:
-            return Path(self.model_path.parent)
-
         return Path(__file__).parent.parent / "data"
 
@@ -114,14 +110,15 @@ class BaseConfig:
         return cls(**state)
 
     @classmethod
-    def from_model( cls, model_path ):
+    def from_model( cls, model_path, lora_path=None ):
         if not model_path.exists():
             raise Exception(f'Model path does not exist: {model_path}')
 
         # load state dict and copy its stored model config
-        state_dict = torch_load( model_path )["config"]
+        model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } ] if model_path and model_path.exists() else []
+        lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []
 
-        state = { "models": [ state_dict ], "trainer": { "load_state_dict": True }, "model_path": model_path }
+        state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }
         return cls(**state)
 
     @classmethod
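
The heart of from_model: each checkpoint's embedded config dict is merged with the path it was loaded from via dict-union (PEP 584), and the results seed the models / loras lists. The merge idiom in isolation, as a standalone sketch with a made-up dict standing in for torch_load( model_path )["config"]:

    from pathlib import Path

    stored_config = { "name": "ar+nar", "dim": 1024 }  # stands in for torch_load( model_path )["config"]
    model_path = Path("./model.sft")                   # hypothetical checkpoint location

    # the right-hand operand wins on key collisions, so "path" is attached to the stored config
    entry = stored_config | { "path": model_path }
    assert entry == { "name": "ar+nar", "dim": 1024, "path": model_path }
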
@@ -134,10 +131,11 @@ class BaseConfig:
         parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
         parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
         parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
+        parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None)) # os environ so it can be specified in a HuggingFace Space too
         args, unknown = parser.parse_known_args(args=args)
 
         if args.model:
-            return cls.from_model( args.model )
+            return cls.from_model( args.model, args.lora )
 
         if args.yaml:
             return cls.from_yaml( args.yaml )
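
As the inline comments note, the argument defaults fall back to environment variables so a HuggingFace Space can configure the model with no command line, and parse_known_args tolerates flags meant for other parsers instead of erroring. The pattern in isolation (runnable as-is; VALLE_LORA is the variable named in the diff):

    import argparse, os
    from pathlib import Path

    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
    parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None))

    # unknown flags are collected rather than raising, so another parser can claim them
    args, unknown = parser.parse_known_args(["--lora", "lora.sft", "--verbose"])
    print(args.lora, unknown)  # lora.sft ['--verbose']
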
@@ -276,6 +274,7 @@ class Model:
     frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
     attention: str = "auto" # for llama arch_types: attention used
     dropout: float = 0.1 # adjustable dropout value
+    path: Path | None = None
     #loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
     loss_factors: dict = field(default_factory=lambda: {})
     capabilities: list = field(default_factory=lambda: ["ar", "nar"]) # + ["lang", "tone"] if you have your dataset labeled for such
@@ -408,6 +407,7 @@ class LoRA:
     embeddings: bool = False # train the embedding too
     parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
     rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
+    path: Path | None = None
 
     @property
     def full_name(self):
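
Model and LoRA now each carry an optional path, replacing the global BaseConfig.model_path removed in the first config.py hunk: the weight location travels with the definition it belongs to, which is what lets from_model inject "path" into each config dict. A toy version of the round trip (dummy dataclass, not the real Model):

    from dataclasses import dataclass
    from pathlib import Path

    @dataclass
    class Entry:                  # hypothetical stand-in for Model / LoRA
        name: str = ""
        path: Path | None = None  # pinned weight file, if any

    # the dict built by from_model deserializes straight into the dataclass
    entry = Entry(**({ "name": "ar+nar" } | { "path": Path("./model.sft") }))
    print(entry.path)             # model.sft
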
@@ -832,8 +832,8 @@ class Config(BaseConfig):
         tmp = Config.from_yaml( config_path )
         self.__dict__.update(tmp.__dict__)
 
-    def load_model( self, config_path ):
-        tmp = Config.from_model( config_path )
+    def load_model( self, config_path, lora_path=None ):
+        tmp = Config.from_model( config_path, lora_path )
         self.__dict__.update(tmp.__dict__)
 
     def load_hdf5( self, write=False ):
@@ -870,6 +870,9 @@ class Config(BaseConfig):
 
 
     def format( self, training=True ):
+        print( self.models )
+        print( self.loras )
+
         if isinstance(self.dataset, type):
             self.dataset = dict()
 
@@ -949,7 +952,6 @@ class Config(BaseConfig):
                 _logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}")
                 model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
 
-
         self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
         self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
 
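
format() is where those dict entries become real instances; anything already deserialized passes through untouched. The normalization idiom on its own (dummy single-field dataclass; the real Model has many more fields):

    from dataclasses import dataclass

    @dataclass
    class Model:
        name: str = ""

    raw = [ { "name": "ar+nar" }, Model(name="prebuilt") ]
    models = [ Model(**m) if isinstance(m, dict) else m for m in raw ]
    print([ m.name for m in models ])  # ['ar+nar', 'prebuilt']
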
@@ -42,6 +42,8 @@ def main():
     parser = argparse.ArgumentParser("VALL-E TTS Demo")
 
     parser.add_argument("--yaml", type=Path, default=None)
+    parser.add_argument("--model", type=Path, default=None)
+    parser.add_argument("--lora", type=Path, default=None)
 
     parser.add_argument("--demo-dir", type=Path, default=None)
     parser.add_argument("--skip-existing", action="store_true")
@@ -93,8 +95,14 @@ def main():
     parser.add_argument("--comparison", type=str, default=None)
 
     args = parser.parse_args()
 
-    tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
+    config = None
+    if args.yaml:
+        config = args.yaml
+    elif args.model:
+        config = args.model
+
+    tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp )
 
     if not args.demo_dir:
         args.demo_dir = Path("./data/demo/")

@@ -58,8 +58,8 @@ def load_engines(training=True, **model_kwargs):
             checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
 
         # if loaded using --model=
-        if cfg.model_path and cfg.model_path.exists():
-            load_path = cfg.model_path
+        if model.config.path and model.config.path.exists():
+            load_path = model.config.path
 
         if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
             _logger.warning(f"Checkpoint missing, but weights found: {load_path}")
@@ -208,7 +208,11 @@ def load_engines(training=True, **model_kwargs):
 
         # load lora weights if exists
         if cfg.lora is not None:
-            lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
+            if cfg.lora.path:
+                lora_path = cfg.lora.path
+            else:
+                lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
 
             if lora_path.exists():
                 _logger.info( f"Loaded LoRA state dict: {lora_path}" )
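
Weight resolution now prefers a pinned path over the checkpoint-directory convention, for the base model (model.config.path) and the LoRA (cfg.lora.path) alike. Reduced to a standalone sketch (the directory layout and lora.sft filename mirror the pick_path call above, but resolve_lora_path itself is a hypothetical helper):

    from pathlib import Path

    def resolve_lora_path(pinned: Path | None, ckpt_dir: Path, full_name: str) -> Path:
        # an explicitly pinned file (set via --lora) wins,
        # otherwise fall back to <ckpt_dir>/<full_name>/lora.sft
        if pinned:
            return pinned
        return ckpt_dir / full_name / "lora.sft"

    print(resolve_lora_path(None, Path("./ckpt"), "ar+nar-lora"))                # ckpt/ar+nar-lora/lora.sft
    print(resolve_lora_path(Path("./lora.sft"), Path("./ckpt"), "ar+nar-lora"))  # lora.sft
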
@@ -25,16 +25,16 @@ if deepspeed_available:
     import deepspeed
 
 class TTS():
-    def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
+    def __init__( self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None ):
         self.loading = True
 
         # yes I can just grab **kwargs and forward them here
-        self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
+        self.load_config( config=config, lora=lora, device=device, amp=amp, dtype=dtype, attention=attention )
         self.load_model()
 
         self.loading = False
 
-    def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
+    def load_config( self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None ):
         if not config:
             download_model()
             config = DEFAULT_MODEL_PATH
@@ -44,7 +44,7 @@ class TTS():
             cfg.load_yaml( config )
         elif config.suffix == ".sft":
             _logger.info(f"Loading model: {config}")
-            cfg.load_model( config )
+            cfg.load_model( config, lora )
         else:
             raise Exception(f"Unknown config passed: {config}")
 
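
load_config dispatches purely on the file suffix, so whatever arrives as config must be a pathlib.Path: .yaml routes to cfg.load_yaml, .sft to cfg.load_model (now carrying the LoRA along), and anything else raises. The shape of that dispatch, standalone with illustrative return values:

    from pathlib import Path

    def describe(config: Path) -> str:
        if config.suffix == ".yaml":
            return "yaml config"
        elif config.suffix == ".sft":
            return "model checkpoint, optionally with a LoRA alongside"
        raise Exception(f"Unknown config passed: {config}")

    print(describe(Path("model.sft")))  # model checkpoint, optionally with a LoRA alongside
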
@@ -112,7 +112,7 @@ def load_sample( speaker ):
 
     return data, (sr, wav)
 
-def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=None):
+def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto", attention=None):
     global tts
 
     if tts is not None:
@@ -125,6 +125,7 @@ def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=
     parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
     parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
     parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
+    parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None)) # os environ so it can be specified in a HuggingFace Space too
    parser.add_argument("--device", type=str, default=device)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default=dtype)
@@ -137,12 +138,18 @@ def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=
     elif config.suffix == ".sft" and not args.model:
         args.model = config
 
+    if lora and not args.lora:
+        args.lora = lora
+
     if args.yaml:
         config = args.yaml
     elif args.model:
         config = args.model
 
-    tts = TTS( config=config, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
+    if args.lora:
+        lora = args.lora
+
+    tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
     return tts
 
 @gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
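
In the web UI, a lora passed into init_tts only seeds args.lora when neither the --lora flag nor VALLE_LORA supplied one, so explicit CLI/environment settings win; the final TTS call then reads args.lora. The precedence as a hypothetical helper:

    def effective_lora(call_arg, parsed_arg):
        # mirrors: if lora and not args.lora: args.lora = lora
        return parsed_arg if parsed_arg else call_arg

    print(effective_lora("from-call.sft", None))       # from-call.sft
    print(effective_lora("from-call.sft", "cli.sft"))  # cli.sft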