added option to load lora directly from the model file itself with --lora

mrq 2024-10-26 00:13:10 -05:00
parent 023c3af331
commit 4049f51ba9
6 changed files with 52 additions and 22 deletions

View File

@@ -17,6 +17,8 @@ def main():
     parser.add_argument("--out-path", type=Path, default=None)
     parser.add_argument("--yaml", type=Path, default=None)
+    parser.add_argument("--model", type=Path, default=None)
+    parser.add_argument("--lora", type=Path, default=None)
     parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
     parser.add_argument("--max-nar-levels", type=int, default=7)
@@ -53,7 +55,14 @@ def main():
     parser.add_argument("--attention", type=str, default=None)
     args = parser.parse_args()
-    tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
+    config = None
+    if args.yaml:
+        config = args.yaml
+    elif args.model:
+        config = args.model
+    tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
     output = tts.inference(
         text=args.text,
         references=args.references,
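Taken together, the entry point now resolves its weights source with --yaml taking precedence over --model, and hands any --lora path through to TTS. A minimal sketch of equivalent programmatic use (the file paths are hypothetical placeholders; the vall_e.inference import reflects the TTS class shown further below):

    from pathlib import Path
    from vall_e.inference import TTS

    # same effect as: python -m vall_e ... --model ./model.sft --lora ./lora.sft
    tts = TTS( config=Path("./model.sft"), lora=Path("./lora.sft") )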

View File

@@ -27,15 +27,11 @@ from .utils import set_seed, prune_missing
 @dataclass()
 class BaseConfig:
     yaml_path: str | None = None # path passed in through --yaml
-    model_path: str | None = None # path passed in through --model
     @property
     def cfg_path(self):
         if self.yaml_path:
             return Path(self.yaml_path.parent)
-        if self.model_path:
-            return Path(self.model_path.parent)
         return Path(__file__).parent.parent / "data"
@@ -114,14 +110,15 @@ class BaseConfig:
         return cls(**state)
     @classmethod
-    def from_model( cls, model_path ):
+    def from_model( cls, model_path, lora_path=None ):
         if not model_path.exists():
             raise Exception(f'Model path does not exist: {model_path}')
         # load state dict and copy its stored model config
-        state_dict = torch_load( model_path )["config"]
+        model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } ] if model_path and model_path.exists() else []
+        lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []
-        state = { "models": [ state_dict ], "trainer": { "load_state_dict": True }, "model_path": model_path }
+        state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }
         return cls(**state)
     @classmethod
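The list-or-empty construction above leans on Python 3.9's dict union operator: the config dict stored inside the checkpoint is merged with a "path" key so downstream loading knows which file the weights came from, and a missing LoRA simply yields an empty loras list. A self-contained sketch of the pattern, with a plain dict standing in for torch_load(...)["config"] (all values hypothetical):

    from pathlib import Path

    stored_config = { "name": "ar+nar", "dim": 1024 }   # pretend checkpoint config
    model_path = Path("./model.sft")                    # hypothetical path

    # dict | dict returns a new dict; the right-hand side wins on key clashes
    entry = stored_config | { "path": model_path }

    models = [ entry ] if model_path and model_path.exists() else []
    loras  = []  # stays empty when no lora_path was given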
@@ -134,10 +131,11 @@ class BaseConfig:
         parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
         parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
         parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
+        parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None)) # os environ so it can be specified in a HuggingFace Space too
         args, unknown = parser.parse_known_args(args=args)
         if args.model:
-            return cls.from_model( args.model )
+            return cls.from_model( args.model, args.lora )
         if args.yaml:
             return cls.from_yaml( args.yaml )
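This parser deliberately uses parse_known_args: the same argv may also carry flags meant for other components, so anything it does not recognize must pass through instead of erroring. A small demonstration with a hypothetical argv (the --listen flag is just an example of an unrelated option):

    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
    parser.add_argument("--model", type=Path, default=None)
    parser.add_argument("--lora", type=Path, default=None)

    # unrecognized arguments land in `unknown` rather than raising
    args, unknown = parser.parse_known_args(
        ["--model", "m.sft", "--lora", "l.sft", "--listen", "0.0.0.0"]
    )
    assert args.lora == Path("l.sft")
    assert unknown == ["--listen", "0.0.0.0"]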
@@ -276,6 +274,7 @@ class Model:
     frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
     attention: str = "auto" # for llama arch_types: attention used
     dropout: float = 0.1 # adjustable dropout value
+    path: Path | None = None
     #loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
     loss_factors: dict = field(default_factory=lambda: {})
     capabilities: list = field(default_factory=lambda: ["ar", "nar"]) # + ["lang", "tone"] if you have your dataset labeled for such
@@ -408,6 +407,7 @@ class LoRA:
     embeddings: bool = False # train the embedding too
     parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
     rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
+    path: Path | None = None
     @property
     def full_name(self):
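Both dataclasses gaining a path field is the crux of the commit: a Model or LoRA config deserialized out of a weights file now remembers which file it came from, so the engine loader no longer needs a single global cfg.model_path. A reduced illustration (class and field names trimmed to the relevant two; the stub is hypothetical):

    from dataclasses import dataclass
    from pathlib import Path

    @dataclass
    class LoRAStub:               # stand-in for the real LoRA config class
        name: str = "lora"
        path: Path | None = None  # file the weights were loaded from, if any

    assert LoRAStub(path=Path("./lora.sft")).path is not None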
@@ -832,8 +832,8 @@ class Config(BaseConfig):
         tmp = Config.from_yaml( config_path )
         self.__dict__.update(tmp.__dict__)
-    def load_model( self, config_path ):
-        tmp = Config.from_model( config_path )
+    def load_model( self, config_path, lora_path=None ):
+        tmp = Config.from_model( config_path, lora_path )
         self.__dict__.update(tmp.__dict__)
     def load_hdf5( self, write=False ):
@@ -870,6 +870,9 @@ class Config(BaseConfig):
     def format( self, training=True ):
+        print( self.models )
+        print( self.loras )
         if isinstance(self.dataset, type):
             self.dataset = dict()
@@ -949,7 +952,6 @@ class Config(BaseConfig):
             _logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}")
             model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
         self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
         self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
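This normalization pass accepts either form for each entry: raw dicts (as stored in a checkpoint or a YAML) are promoted to their dataclass, while already-constructed instances pass through untouched. The idiom in isolation, with a hypothetical stub class:

    from dataclasses import dataclass

    @dataclass
    class Stub:                 # stand-in for Model / LoRA
        name: str = ""

    entries = [ { "name": "a" }, Stub(name="b") ]   # mixed dicts and instances
    entries = [ Stub(**e) if isinstance(e, dict) else e for e in entries ]
    assert all( isinstance(e, Stub) for e in entries )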

View File

@@ -42,6 +42,8 @@ def main():
     parser = argparse.ArgumentParser("VALL-E TTS Demo")
     parser.add_argument("--yaml", type=Path, default=None)
+    parser.add_argument("--model", type=Path, default=None)
+    parser.add_argument("--lora", type=Path, default=None)
     parser.add_argument("--demo-dir", type=Path, default=None)
     parser.add_argument("--skip-existing", action="store_true")
@@ -93,8 +95,14 @@ def main():
     parser.add_argument("--comparison", type=str, default=None)
     args = parser.parse_args()
+    config = None
+    if args.yaml:
+        config = args.yaml
+    elif args.model:
+        config = args.model
-    tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
+    tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp )
     if not args.demo_dir:
         args.demo_dir = Path("./data/demo/")

View File

@@ -58,8 +58,8 @@ def load_engines(training=True, **model_kwargs):
             checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
         # if loaded using --model=
-        if cfg.model_path and cfg.model_path.exists():
-            load_path = cfg.model_path
+        if model.config.path and model.config.path.exists():
+            load_path = model.config.path
         if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
             _logger.warning(f"Checkpoint missing, but weights found: {load_path}")
@@ -208,7 +208,11 @@ def load_engines(training=True, **model_kwargs):
         # load lora weights if exists
         if cfg.lora is not None:
-            lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
+            if cfg.lora.path:
+                lora_path = cfg.lora.path
+            else:
+                lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
             if lora_path.exists():
                 _logger.info( f"Loaded LoRA state dict: {lora_path}" )

View File

@@ -25,16 +25,16 @@ if deepspeed_available:
     import deepspeed
 class TTS():
-    def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
+    def __init__( self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None ):
         self.loading = True
         # yes I can just grab **kwargs and forward them here
-        self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
+        self.load_config( config=config, lora=lora, device=device, amp=amp, dtype=dtype, attention=attention )
         self.load_model()
         self.loading = False
-    def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
+    def load_config( self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None ):
         if not config:
             download_model()
             config = DEFAULT_MODEL_PATH
@@ -44,7 +44,7 @@ class TTS():
             cfg.load_yaml( config )
         elif config.suffix == ".sft":
             _logger.info(f"Loading model: {config}")
-            cfg.load_model( config )
+            cfg.load_model( config, lora )
         else:
             raise Exception(f"Unknown config passed: {config}")

View File

@@ -112,7 +112,7 @@ def load_sample( speaker ):
     return data, (sr, wav)
-def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=None):
+def init_tts(config=None, lora=None, restart=False, device="cuda", dtype="auto", attention=None):
     global tts
     if tts is not None:
@@ -125,6 +125,7 @@ def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=None):
     parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
     parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
     parser.add_argument("--model", type=Path, default=os.environ.get('VALLE_MODEL', None)) # os environ so it can be specified in a HuggingFace Space too
+    parser.add_argument("--lora", type=Path, default=os.environ.get('VALLE_LORA', None)) # os environ so it can be specified in a HuggingFace Space too
     parser.add_argument("--device", type=str, default=device)
     parser.add_argument("--amp", action="store_true")
     parser.add_argument("--dtype", type=str, default=dtype)
@@ -137,12 +138,18 @@ def init_tts(config=None, restart=False, device="cuda", dtype="auto", attention=None):
         elif config.suffix == ".sft" and not args.model:
             args.model = config
+    if lora and not args.lora:
+        args.lora = lora
     if args.yaml:
         config = args.yaml
     elif args.model:
         config = args.model
+    if args.lora:
+        lora = args.lora
-    tts = TTS( config=config, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
+    tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
     return tts
 @gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
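For the webui the flags mirror the CLI, with VALLE_YAML / VALLE_MODEL / VALLE_LORA as environment fallbacks so a HuggingFace Space can configure them without argv. A hypothetical headless use (paths are placeholders, and the env vars must be set before init_tts builds its parser, since the defaults are read at parse time):

    import os
    os.environ["VALLE_MODEL"] = "./model.sft"   # hypothetical paths
    os.environ["VALLE_LORA"] = "./lora.sft"

    from vall_e.webui import init_tts           # module shown above
    tts = init_tts()                            # picks both up via the env defaults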