Compare commits

...

2 Commits

5 changed files with 26 additions and 8 deletions

View File

@ -194,4 +194,10 @@ However, output leaves a lot to be desired:
* both the small and the large model seemed to have hit a "capacity" limit
* the "confidence" problem of the prior implementation seems to have emerged even for typical speakers
* some other quirks and emergent behaviors inherent to the model I'm not aware of / can't recall
* such as the demasking sampler loop being quite particular
* such as the demasking sampler loop being quite particular
* naturally, LoRAs are trainable:
* at a glance it seems to address the problems of poor/inconsistent zero-shot performance
* training a LoRA is agonizing because the loss doesn't progress anywhere near as nicely as it does against EnCodec-based models
* however, there seems to be a problem when predicting the duration that causes it to be too short (when the input prompt is of the speaker) or too long (when the input prompt is not of the speaker)
* simply disabling the LoRA specifically for duration prediction seems to fix this
* additional testing against LoRAs is necessary to draw further conclusions

View File

@ -102,10 +102,18 @@ class BaseConfig:
return yaml
@classmethod
def from_yaml( cls, yaml_path ):
def from_yaml( cls, yaml_path, lora_path=None ):
state = {}
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
state.setdefault("yaml_path", yaml_path)
if lora_path:
if not lora_path.exists():
raise Exception(f'LoRA path does not exist: {lora_path}')
lora_state_dict = torch_load( lora_path ) if lora_path and lora_path.exists() else None
state["loras"] = [ lora_state_dict["config"] | { "path": lora_path } ] if lora_state_dict is not None else []
state = cls.prune_missing( state )
return cls(**state)
@ -157,7 +165,7 @@ class BaseConfig:
return cls.from_model( args.model, args.lora )
if args.yaml:
return cls.from_yaml( args.yaml )
return cls.from_yaml( args.yaml, args.lora )
return cls(**{})
@ -981,8 +989,8 @@ class Config(BaseConfig):
"""
# this gets called from vall_e.inference
def load_yaml( self, config_path ):
tmp = Config.from_yaml( config_path )
def load_yaml( self, config_path, lora_path=None ):
tmp = Config.from_yaml( config_path, lora_path )
self.__dict__.update(tmp.__dict__)
def load_model( self, config_path, lora_path=None ):

View File

@ -882,7 +882,7 @@ class Dataset(_Dataset):
# split to retain tuples
flattened[bucket] = self.duration_buckets[bucket]
# replace with path
flattened[bucket] = [ x[0] for x in flattened[bucket] ]
# flattened[bucket] = [ x[0] for x in flattened[bucket] ]
# flatten by paths
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], lambda x: x[0])]
# flatten paths

View File

@ -51,7 +51,7 @@ class TTS():
if config.suffix == ".yaml":
_logger.info(f"Loading YAML: {config}")
cfg.load_yaml( config )
cfg.load_yaml( config, lora )
elif config.suffix == ".sft":
_logger.info(f"Loading model: {config}")
cfg.load_model( config, lora )
@ -114,6 +114,8 @@ class TTS():
return text
# check if tokenizes without any unks (for example, if already phonemized text is passes)
# to-do: properly fix this
# - i don't remember what specific situation arised where phonemized text is already passed in to warrant the need to detect it
"""
if precheck and "<unk>" in self.symmap:
tokens = tokenize( text )

View File

@ -400,7 +400,9 @@ class AR_NAR_V2(Base_V2):
batch_size = len(proms_list)
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
# enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
# force disable LoRAs for this
enable_lora( self, False )
task_list = [ "len" for _ in range( batch_size ) ]
quant_levels = [ 0 for _ in range( batch_size ) ]