fixing an error I caught while fixing tortoise_tts, possibly actually load a LoRA if not passing a yaml/model
This commit is contained in:
parent
98b357cc53
commit
7f4206a879
@ -102,10 +102,18 @@ class BaseConfig:
|
||||
return yaml
|
||||
|
||||
@classmethod
|
||||
def from_yaml( cls, yaml_path ):
|
||||
def from_yaml( cls, yaml_path, lora_path=None ):
|
||||
state = {}
|
||||
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
|
||||
state.setdefault("yaml_path", yaml_path)
|
||||
|
||||
if lora_path:
|
||||
if not lora_path.exists():
|
||||
raise Exception(f'LoRA path does not exist: {lora_path}')
|
||||
|
||||
lora_state_dict = torch_load( lora_path ) if lora_path and lora_path.exists() else None
|
||||
state["loras"] = [ lora_state_dict["config"] | { "path": lora_path } ] if lora_state_dict is not None else []
|
||||
|
||||
state = cls.prune_missing( state )
|
||||
return cls(**state)
|
||||
|
||||
@ -157,7 +165,7 @@ class BaseConfig:
|
||||
return cls.from_model( args.model, args.lora )
|
||||
|
||||
if args.yaml:
|
||||
return cls.from_yaml( args.yaml )
|
||||
return cls.from_yaml( args.yaml, args.lora )
|
||||
|
||||
return cls(**{})
|
||||
|
||||
@ -981,8 +989,8 @@ class Config(BaseConfig):
|
||||
"""
|
||||
|
||||
# this gets called from vall_e.inference
|
||||
def load_yaml( self, config_path ):
|
||||
tmp = Config.from_yaml( config_path )
|
||||
def load_yaml( self, config_path, lora_path=None ):
|
||||
tmp = Config.from_yaml( config_path, lora_path )
|
||||
self.__dict__.update(tmp.__dict__)
|
||||
|
||||
def load_model( self, config_path, lora_path=None ):
|
||||
|
||||
@ -882,7 +882,7 @@ class Dataset(_Dataset):
|
||||
# split to retain tuples
|
||||
flattened[bucket] = self.duration_buckets[bucket]
|
||||
# replace with path
|
||||
flattened[bucket] = [ x[0] for x in flattened[bucket] ]
|
||||
# flattened[bucket] = [ x[0] for x in flattened[bucket] ]
|
||||
# flatten by paths
|
||||
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], lambda x: x[0])]
|
||||
# flatten paths
|
||||
|
||||
@ -51,7 +51,7 @@ class TTS():
|
||||
|
||||
if config.suffix == ".yaml":
|
||||
_logger.info(f"Loading YAML: {config}")
|
||||
cfg.load_yaml( config )
|
||||
cfg.load_yaml( config, lora )
|
||||
elif config.suffix == ".sft":
|
||||
_logger.info(f"Loading model: {config}")
|
||||
cfg.load_model( config, lora )
|
||||
|
||||
Loading…
Reference in New Issue
Block a user