Compare commits
2 Commits
0cca4eb943
...
7f4206a879
| Author | SHA1 | Date | |
|---|---|---|---|
| 7f4206a879 | |||
| 98b357cc53 |
@ -194,4 +194,10 @@ However, output leaves a lot to be desired:
|
||||
* both the small and the large model seemed to have hit a "capacity" limit
|
||||
* the "confidence" problem of the prior implementation seems to have emerged even for typical speakers
|
||||
* some other quirks and emergent behaviors inherent to the model I'm not aware of / can't recall
|
||||
* such as the demasking sampler loop being quite particular
|
||||
* such as the demasking sampler loop being quite particular
|
||||
* naturally, LoRAs are trainable:
|
||||
* at a glance it seems to address the problems of poor/inconsistent zero-shot performance
|
||||
* training a LoRA is agonizing because the loss doesn't progress anywhere near as nicely as it does against EnCodec-based models
|
||||
* however, there seems to be a problem when predicting the duration that causes it to be too short (when the input prompt is of the speaker) or too long (when the input prompt is not of the speaker)
|
||||
* simply disabling the LoRA specifically for duration prediction seems to fix this
|
||||
* additional testing against LoRAs is necessary to draw further conclusions
|
||||
@ -102,10 +102,18 @@ class BaseConfig:
|
||||
return yaml
|
||||
|
||||
@classmethod
|
||||
def from_yaml( cls, yaml_path ):
|
||||
def from_yaml( cls, yaml_path, lora_path=None ):
|
||||
state = {}
|
||||
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
|
||||
state.setdefault("yaml_path", yaml_path)
|
||||
|
||||
if lora_path:
|
||||
if not lora_path.exists():
|
||||
raise Exception(f'LoRA path does not exist: {lora_path}')
|
||||
|
||||
lora_state_dict = torch_load( lora_path ) if lora_path and lora_path.exists() else None
|
||||
state["loras"] = [ lora_state_dict["config"] | { "path": lora_path } ] if lora_state_dict is not None else []
|
||||
|
||||
state = cls.prune_missing( state )
|
||||
return cls(**state)
|
||||
|
||||
@ -157,7 +165,7 @@ class BaseConfig:
|
||||
return cls.from_model( args.model, args.lora )
|
||||
|
||||
if args.yaml:
|
||||
return cls.from_yaml( args.yaml )
|
||||
return cls.from_yaml( args.yaml, args.lora )
|
||||
|
||||
return cls(**{})
|
||||
|
||||
@ -981,8 +989,8 @@ class Config(BaseConfig):
|
||||
"""
|
||||
|
||||
# this gets called from vall_e.inference
|
||||
def load_yaml( self, config_path ):
|
||||
tmp = Config.from_yaml( config_path )
|
||||
def load_yaml( self, config_path, lora_path=None ):
|
||||
tmp = Config.from_yaml( config_path, lora_path )
|
||||
self.__dict__.update(tmp.__dict__)
|
||||
|
||||
def load_model( self, config_path, lora_path=None ):
|
||||
|
||||
@ -882,7 +882,7 @@ class Dataset(_Dataset):
|
||||
# split to retain tuples
|
||||
flattened[bucket] = self.duration_buckets[bucket]
|
||||
# replace with path
|
||||
flattened[bucket] = [ x[0] for x in flattened[bucket] ]
|
||||
# flattened[bucket] = [ x[0] for x in flattened[bucket] ]
|
||||
# flatten by paths
|
||||
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], lambda x: x[0])]
|
||||
# flatten paths
|
||||
|
||||
@ -51,7 +51,7 @@ class TTS():
|
||||
|
||||
if config.suffix == ".yaml":
|
||||
_logger.info(f"Loading YAML: {config}")
|
||||
cfg.load_yaml( config )
|
||||
cfg.load_yaml( config, lora )
|
||||
elif config.suffix == ".sft":
|
||||
_logger.info(f"Loading model: {config}")
|
||||
cfg.load_model( config, lora )
|
||||
@ -114,6 +114,8 @@ class TTS():
|
||||
return text
|
||||
|
||||
# check if tokenizes without any unks (for example, if already phonemized text is passes)
|
||||
# to-do: properly fix this
|
||||
# - i don't remember what specific situation arised where phonemized text is already passed in to warrant the need to detect it
|
||||
"""
|
||||
if precheck and "<unk>" in self.symmap:
|
||||
tokens = tokenize( text )
|
||||
|
||||
@ -400,7 +400,9 @@ class AR_NAR_V2(Base_V2):
|
||||
batch_size = len(proms_list)
|
||||
|
||||
if cfg.lora is not None:
|
||||
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
|
||||
# enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
|
||||
# force disable LoRAs for this
|
||||
enable_lora( self, False )
|
||||
|
||||
task_list = [ "len" for _ in range( batch_size ) ]
|
||||
quant_levels = [ 0 for _ in range( batch_size ) ]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user