fixing an error I caught while fixing tortoise_tts; also actually load a LoRA when passing a YAML instead of a model

mrq 2025-07-24 20:56:09 -05:00
parent 98b357cc53
commit 7f4206a879
3 changed files with 14 additions and 6 deletions

View File

@@ -102,10 +102,18 @@ class BaseConfig:
 		return yaml
 
 	@classmethod
-	def from_yaml( cls, yaml_path ):
+	def from_yaml( cls, yaml_path, lora_path=None ):
 		state = {}
 		state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
 		state.setdefault("yaml_path", yaml_path)
+
+		if lora_path:
+			if not lora_path.exists():
+				raise Exception(f'LoRA path does not exist: {lora_path}')
+
+			lora_state_dict = torch_load( lora_path ) if lora_path and lora_path.exists() else None
+			state["loras"] = [ lora_state_dict["config"] | { "path": lora_path } ] if lora_state_dict is not None else []
+
 		state = cls.prune_missing( state )
 		return cls(**state)
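The | in the added branch is dict union (Python 3.9+): the checkpoint's stored config is merged with the on-disk path, with right-hand keys winning on collision. A runnable sketch with hypothetical checkpoint contents (rank/alpha are placeholder keys; the real dict is whatever torch_load returns for the checkpoint):

# stand-in for what torch_load( lora_path ) would return; keys are hypothetical
lora_state_dict = { "config": { "rank": 128, "alpha": 128 }, "module": {} }
lora_path = "./loras/lora.sft"  # hypothetical path

entry = lora_state_dict["config"] | { "path": lora_path }
print(entry)  # {'rank': 128, 'alpha': 128, 'path': './loras/lora.sft'}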
@@ -157,7 +165,7 @@ class BaseConfig:
 			return cls.from_model( args.model, args.lora )
 
 		if args.yaml:
-			return cls.from_yaml( args.yaml )
+			return cls.from_yaml( args.yaml, args.lora )
 
 		return cls(**{})
@@ -981,8 +989,8 @@ class Config(BaseConfig):
 	"""
 	# this gets called from vall_e.inference
-	def load_yaml( self, config_path ):
-		tmp = Config.from_yaml( config_path )
+	def load_yaml( self, config_path, lora_path=None ):
+		tmp = Config.from_yaml( config_path, lora_path )
 		self.__dict__.update(tmp.__dict__)
 
 	def load_model( self, config_path, lora_path=None ):
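Taken together, the hunks above thread one optional argument down the chain load_yaml → from_yaml → state["loras"]. A hedged usage sketch; the import path and file names are assumptions, since the diff only shows the class bodies:

from pathlib import Path
from vall_e.config import cfg  # assumed import; cfg is the global Config instance the inference code calls into

config_path = Path("./training/config.yaml")  # hypothetical paths
lora_path   = Path("./training/lora.sft")

# previously the LoRA was dropped on the YAML path; now it lands in cfg.loras via from_yaml()
cfg.load_yaml( config_path, lora_path )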

View File

@@ -882,7 +882,7 @@ class Dataset(_Dataset):
 				# split to retain tuples
 				flattened[bucket] = self.duration_buckets[bucket]
 				# replace with path
-				flattened[bucket] = [ x[0] for x in flattened[bucket] ]
+				# flattened[bucket] = [ x[0] for x in flattened[bucket] ]
 				# flatten by paths
 				flattened[bucket] = [*_interleaved_reorder(flattened[bucket], lambda x: x[0])]
 				# flatten paths
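The commented-out line is the substance of this hunk: _interleaved_reorder is keyed with lambda x: x[0], which assumes each bucket entry is still a (path, duration) tuple. Flattening to bare paths first made that key index into the path itself (or raise TypeError for pathlib.Path entries). A runnable sketch with hypothetical string entries, the element types being inferred from the surrounding comments:

bucket = [ ("spk_a/utt1.wav", 1.4), ("spk_b/utt2.wav", 2.0) ]  # hypothetical (path, duration) tuples
key = lambda x: x[0]

print([ key(x) for x in bucket ])      # ['spk_a/utt1.wav', 'spk_b/utt2.wav'], paths as intended
flattened = [ x[0] for x in bucket ]   # the removed step
print([ key(x) for x in flattened ])   # ['s', 's'], the first character of each path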

View File

@@ -51,7 +51,7 @@ class TTS():
 		if config.suffix == ".yaml":
 			_logger.info(f"Loading YAML: {config}")
-			cfg.load_yaml( config )
+			cfg.load_yaml( config, lora )
 		elif config.suffix == ".sft":
 			_logger.info(f"Loading model: {config}")
 			cfg.load_model( config, lora )
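With this, the .yaml branch finally matches the .sft branch. A usage sketch, assuming TTS() accepts config and lora arguments as the local names above suggest:

from pathlib import Path
from vall_e.inference import TTS  # module path per the "this gets called from vall_e.inference" comment earlier

# hypothetical file names; a LoRA passed alongside a YAML config is now actually loaded
tts = TTS( config=Path("./training/config.yaml"), lora=Path("./training/lora.sft") )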