diff --git a/vall_e/config.py b/vall_e/config.py index d7fc2c3..6adddae 100644 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -102,10 +102,18 @@ class BaseConfig: return yaml @classmethod - def from_yaml( cls, yaml_path ): + def from_yaml( cls, yaml_path, lora_path=None ): state = {} state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8")) state.setdefault("yaml_path", yaml_path) + + if lora_path: + if not lora_path.exists(): + raise Exception(f'LoRA path does not exist: {lora_path}') + + lora_state_dict = torch_load( lora_path ) if lora_path and lora_path.exists() else None + state["loras"] = [ lora_state_dict["config"] | { "path": lora_path } ] if lora_state_dict is not None else [] + state = cls.prune_missing( state ) return cls(**state) @@ -157,7 +165,7 @@ class BaseConfig: return cls.from_model( args.model, args.lora ) if args.yaml: - return cls.from_yaml( args.yaml ) + return cls.from_yaml( args.yaml, args.lora ) return cls(**{}) @@ -981,8 +989,8 @@ class Config(BaseConfig): """ # this gets called from vall_e.inference - def load_yaml( self, config_path ): - tmp = Config.from_yaml( config_path ) + def load_yaml( self, config_path, lora_path=None ): + tmp = Config.from_yaml( config_path, lora_path ) self.__dict__.update(tmp.__dict__) def load_model( self, config_path, lora_path=None ): diff --git a/vall_e/data.py b/vall_e/data.py index 4106450..d9ab892 100644 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -882,7 +882,7 @@ class Dataset(_Dataset): # split to retain tuples flattened[bucket] = self.duration_buckets[bucket] # replace with path - flattened[bucket] = [ x[0] for x in flattened[bucket] ] + # flattened[bucket] = [ x[0] for x in flattened[bucket] ] # flatten by paths flattened[bucket] = [*_interleaved_reorder(flattened[bucket], lambda x: x[0])] # flatten paths diff --git a/vall_e/inference.py b/vall_e/inference.py index 7f41aa3..c82fe1d 100644 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -51,7 +51,7 @@ class TTS(): if config.suffix == ".yaml": _logger.info(f"Loading YAML: {config}") - cfg.load_yaml( config ) + cfg.load_yaml( config, lora ) elif config.suffix == ".sft": _logger.info(f"Loading model: {config}") cfg.load_model( config, lora )