diff --git a/vall_e/config.py b/vall_e/config.py
index 606f90e..ebaea54 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -169,6 +169,7 @@ class Model:
 	prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
 	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
 	langs: int = 1 # defined languages
+	experts: int = 1
 	arch_type: str = "retnet" # or "transformer"
 	training: bool = True # unneeded now
 	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
@@ -183,12 +184,16 @@ class Model:
 			name.append(self.size)
 
 		if self.arch_type != "transformer":
-			name.append(self.arch_type.replace("/", "-"))
+			if self.experts:
+				name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
+			else:
+				name.append(self.arch_type.replace("/", "-"))
 
 		if self.interleave:
 			name.append("interleaved")
+		else:
+			name.append(f'{cfg.models.prom_levels}')
 
-		name.append(f'{cfg.models.prom_levels}')
 		return "-".join(name)
@@ -247,8 +252,8 @@ class Models:
 	_prom_levels: int = 1
 
 	_models: list[Model] = field(default_factory=lambda: [
-		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
-		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
+		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
+		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
 	])
 
 	def get(self, name=None):
diff --git a/vall_e/data.py b/vall_e/data.py
index c2210bc..4472afe 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -224,7 +224,7 @@ class Dataset(_Dataset):
 			self.spkrs_by_spkr_group[spkr_group].append( spkr )
 
 		self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
-		
+
 		self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
 		if self.sampler_type == "path":
@@ -351,7 +351,7 @@ class Dataset(_Dataset):
 		# shuffle it up a bit
 		prom_length = 0
 		if cfg.experimental:
-			trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
+			trim_length = random.randint(75 * 3, 75 * 6) # [3 seconds, 6 seconds]
 			#trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * 75))
 		else:
 			trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 222d0af..d51a273 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -45,7 +45,7 @@ from .base import TrainFeeder
 
 _logger = logging.getLogger(__name__)
 
-if not distributed_initialized() and cfg.trainer.backend == "local" and world_size() > 1:
+if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
 	init_distributed(torch.distributed.init_process_group)
 
 # A very naive engine implementation using barebones PyTorch
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index bd2fec3..7103409 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -119,10 +119,23 @@ class AR_NAR(Base):
 		# is training
 		if n_levels == self.n_resp_levels:
 			# might be better to have this decided on the dataloader level
-			if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
-				quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+
+			if cfg.experimental:
+				# makes higher levels less likely
+				def generate( lo=0, hi=8 ):
+					index = lo
+					p = random.random()
+					for i in range(lo, hi):
+						if p < 1.0 / (2 ** i):
+							index = i
+					return int(index)
+
+				quant_levels = torch.Tensor([ generate(0, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
 			else:
-				quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])
+				if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
+					quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+				else:
+					quant_levels = torch.Tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
 
 			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
 			resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding
@@ -311,11 +324,21 @@ def example_usage():
 	proms_list = proms_list[:1]
 	resps_list = resps_list[:1]
 
+	"""
 	kwargs = {
 		'n_tokens': 1024,
-		'd_model': 1024, # 1536
-		'n_heads': 16, # 24
+		'd_model': 1024, # 256, # 1024, # 1536
+		'n_heads': 16, # 4, # 16, # 24
 		'n_layers': 12, # 32
+		'n_experts': 8,
+	}
+	"""
+	kwargs = {
+		'n_tokens': 1024,
+		'd_model': 256,
+		'n_heads': 4,
+		'n_layers': 12,
+		'n_experts': 1,
 	}
 	"""
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 34ba3f9..f97c5fc 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -204,6 +204,8 @@ class Base(nn.Module):
 		n_layers: int = 12,
 		p_dropout: float = 0.1,
 
+		n_experts: int=1,
+		config = None,
 	):
 		super().__init__()
@@ -214,6 +216,7 @@ class Base(nn.Module):
 		self.d_model = d_model
 		self.n_heads = n_heads
 		self.n_layers = n_layers
+		self.n_experts = n_experts
 
 		# +1 to include the stop token
 		# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
@@ -272,6 +275,11 @@ class Base(nn.Module):
 			decoder_normalize_before=True,
 
 			rotary_embedding_base=self.rotary_embedding_base, # 10000
+
+			# MoE
+			use_xmoe=n_experts>1,
+			moe_freq=2,
+			moe_expert_count=n_experts,
 		))
 
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
diff --git a/vall_e/utils/distributed.py b/vall_e/utils/distributed.py
index b6364dd..03bb859 100755
--- a/vall_e/utils/distributed.py
+++ b/vall_e/utils/distributed.py
@@ -16,6 +16,7 @@ def get_free_port():
 
 _distributed_initialized = False
 def init_distributed( fn, *args, **kwargs ):
+	print("Initializing distributed...")
 	fn(*args, **kwargs)
 	_distributed_initialized = True
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 5b55cbf..a613ef3 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -77,6 +77,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	# I'm very sure I can procedurally generate this list
 	parser.add_argument("--text", type=str, default=kwargs["text"])
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
+	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
 	parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
 	parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*75))
@@ -104,6 +105,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	with timer() as t:
 		wav, sr = tts.inference(
 			text=args.text,
+			language=args.language,
 			references=[args.references.split(";")],
 			out_path=tmp.name,
 			max_ar_steps=args.max_ar_steps,
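
Note on the quant-level sampling change in vall_e/models/ar_nar.py: `generate` keeps the largest `i` in `[lo, hi)` satisfying `p < 1 / 2**i`, and `i = lo` always passes when `lo = 0`, so level 0 (the AR pass) is drawn half the time and each higher NAR level is half as likely as the one below it. A standalone sketch that checks this empirically; the function is verbatim from the hunk above, while the sampling harness around it is illustrative only:

	import random
	from collections import Counter

	# verbatim from the ar_nar.py hunk: biases sampling toward lower RVQ levels
	def generate( lo=0, hi=8 ):
		index = lo
		p = random.random()
		for i in range(lo, hi):
			if p < 1.0 / (2 ** i):
				index = i
		return int(index)

	# empirical check of the implied distribution
	counts = Counter( generate(0, 8) for _ in range(100_000) )
	for level in sorted(counts):
		print(level, round(counts[level] / 100_000, 3))
	# roughly 0 -> 0.500, 1 -> 0.250, 2 -> 0.125, ...:
	# P(level == i) = 2**-(i+1) for i < hi-1, with the remaining
	# 2**-(hi-1) mass landing on the final level (here, 7)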
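Note on the naming change in vall_e/config.py: non-transformer model names now gain an `{experts}x` prefix on the arch tag, and the `prom_levels` suffix is only appended when the model is not interleaved. Since the default `experts=1` is truthy, the stock AR resolves to `ar-1xretnet-8` rather than `ar-retnet-8`, which affects checkpoint paths keyed on the old name. A minimal sketch of the resulting names (`full_name` here is a hypothetical condensation of the patched property, with the `size` component omitted):

	def full_name( name, arch_type="retnet", experts=1, interleave=False, prom_levels=8 ):
		parts = [ name ]
		if arch_type != "transformer":
			# experts=1 is truthy, so the "1x" prefix is always applied to retnet models
			parts.append( (f'{experts}x' if experts else '') + arch_type.replace("/", "-") )
		parts.append( "interleaved" if interleave else f'{prom_levels}' )
		return "-".join( parts )

	print( full_name("ar") )                    # ar-1xretnet-8
	print( full_name("nar", experts=8) )        # nar-8xretnet-8
	print( full_name("ar", interleave=True) )   # ar-1xretnet-interleaved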
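Note on the MoE flags threaded through vall_e/models/base.py: `use_xmoe=n_experts>1` enables X-MoE-style routing, `moe_expert_count=n_experts` sets the expert count, and `moe_freq=2` asks the torchscale-style stack to substitute a mixture-of-experts feed-forward in every second block. A rough illustration of the layer layout this implies for the commented-out 12-layer, 8-expert example config; whether the first MoE block is block 0 or block 1 is a torchscale implementation detail and is assumed here:

	n_layers, moe_freq, n_experts = 12, 2, 8

	for i in range(n_layers):
		# assumption: every moe_freq-th block, counting from 1, carries the MoE FFN
		is_moe = (i + 1) % moe_freq == 0
		print(i, f"moe ffn ({n_experts} experts)" if is_moe else "dense ffn")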