added torchscale XMOE integration (because Mixtral 8x7B seems very promising and I want to see if it works)
commit 9c198eb75a
parent 6c51a629cc
@@ -169,6 +169,7 @@ class Model:
     prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
     tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
     langs: int = 1 # defined languages
+    experts: int = 1
     arch_type: str = "retnet" # or "transformer""
     training: bool = True # unneeded now
     interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
@@ -183,12 +184,16 @@ class Model:
             name.append(self.size)
 
         if self.arch_type != "transformer":
-            name.append(self.arch_type.replace("/", "-"))
+            if self.experts:
+                name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
+            else:
+                name.append(self.arch_type.replace("/", "-"))
 
         if self.interleave:
             name.append("interleaved")
+        else:
+            name.append(f'{cfg.models.prom_levels}')
 
-        name.append(f'{cfg.models.prom_levels}')
 
         return "-".join(name)
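As a quick sanity check of the naming branch above, here is a standalone sketch of the same logic (not the repo's actual property; the free-standing function and the hard-coded prom_levels default are illustrative stand-ins for cfg.models.prom_levels). Note that, as written in the hunk, `if self.experts:` is truthy for the default experts=1 as well, so non-transformer names pick up a `1x` prefix too:

# Standalone sketch of the full_name branches from the hunk above.
def full_name(name, experts=1, arch_type="retnet", interleave=False, prom_levels=8):
    parts = [name]
    if arch_type != "transformer":
        if experts:  # truthy for the default experts=1 as well
            parts.append(f'{experts}x' + arch_type.replace("/", "-"))
        else:
            parts.append(arch_type.replace("/", "-"))
    if interleave:
        parts.append("interleaved")
    else:
        parts.append(f'{prom_levels}')
    return "-".join(parts)

print(full_name("ar", experts=8))   # ar-8xretnet-8
print(full_name("nar", experts=1))  # nar-1xretnet-8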
@@ -247,8 +252,8 @@ class Models:
     _prom_levels: int = 1
 
     _models: list[Model] = field(default_factory=lambda: [
-        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
-        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
+        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
+        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
     ])
 
     def get(self, name=None):
@@ -224,7 +224,7 @@ class Dataset(_Dataset):
             self.spkrs_by_spkr_group[spkr_group].append( spkr )
 
         self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
 
         self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
         if self.sampler_type == "path":
@@ -351,7 +351,7 @@ class Dataset(_Dataset):
             # shuffle it up a bit
             prom_length = 0
             if cfg.experimental:
-                trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
+                trim_length = random.randint(75 * 3, 75 * 6) # [3 seconds, 6 seconds]
                 #trim_length = max(2, int(np.random.normal(loc=5, scale=1.25) * 75))
             else:
                 trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
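For context on the magic number: assuming the codec here is EnCodec at 24 kHz, as the other *75 conversions in this diff suggest, 75 is the codec frame rate (24000 / 320 hop = 75 frames per second), so this change narrows the experimental prompt trim from roughly 3-9 seconds to 3-6 seconds of audio:

# Illustrative arithmetic only; FRAMES_PER_SECOND is an assumption (EnCodec @ 24 kHz).
import random

FRAMES_PER_SECOND = 75
trim_length = random.randint(75 * 3, 75 * 6)   # 225..450 frames
print(trim_length / FRAMES_PER_SECOND)         # ~3.0..6.0 seconds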
@@ -45,7 +45,7 @@ from .base import TrainFeeder
 
 _logger = logging.getLogger(__name__)
 
-if not distributed_initialized() and cfg.trainer.backend == "local" and world_size() > 1:
+if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
     init_distributed(torch.distributed.init_process_group)
 
 # A very naive engine implementation using barebones PyTorch
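The relaxed guard above means the local backend now always sets up a torch.distributed process group, even for a single process; a plausible reason (an assumption, not stated in the commit) is that the MoE layers go through torch.distributed collectives regardless of world size. A minimal sketch of a single-process group, i.e. what that code path now performs:

# Minimal single-process torch.distributed setup (sketch, not the repo's code).
# The address/port values are illustrative defaults, not the repo's configuration.
import os
import torch

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1)
print(torch.distributed.is_initialized())  # True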
@@ -119,10 +119,23 @@ class AR_NAR(Base):
         # is training
         if n_levels == self.n_resp_levels:
             # might be better to have this decided on the dataloader level
-            if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
-                quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+            if cfg.experimental:
+                # makes higher levels less likely
+                def generate( lo=0, hi=8 ):
+                    index = lo
+                    p = random.random()
+                    for i in range(lo, hi):
+                        if p < 1.0 / (2 ** i):
+                            index = i
+                    return int(index)
+
+                quant_levels = torch.Tensor([ generate(0, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
             else:
-                quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])
+                if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
+                    quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+                else:
+                    quant_levels = torch.Tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
 
             targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
             resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding
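Worth spelling out for the new experimental branch: generate() draws a single uniform p and keeps the largest level i with p < 2^-i, so the distribution over RVQ levels is roughly geometric. Level 0 (the AR pass) is picked about half the time, and each higher NAR level is about half as likely again. A standalone sketch (same logic, illustrative names) to verify:

# Empirical check of the level-sampling bias introduced in the hunk above.
import random
from collections import Counter

def generate(lo=0, hi=8):
    index = lo
    p = random.random()
    for i in range(lo, hi):
        if p < 1.0 / (2 ** i):
            index = i
    return int(index)

counts = Counter(generate(0, 8) for _ in range(100_000))
print({level: round(count / 100_000, 3) for level, count in sorted(counts.items())})
# roughly {0: 0.5, 1: 0.25, 2: 0.125, 3: 0.062, 4: 0.031, 5: 0.016, 6: 0.008, 7: 0.008}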
@@ -311,11 +324,21 @@ def example_usage():
     proms_list = proms_list[:1]
     resps_list = resps_list[:1]
 
+    """
     kwargs = {
         'n_tokens': 1024,
-        'd_model': 1024, # 1536
-        'n_heads': 16, # 24
+        'd_model': 1024, # 256, # 1024, # 1536
+        'n_heads': 16, # 4, # 16, # 24
         'n_layers': 12, # 32
+        'n_experts': 8,
+    }
+    """
+    kwargs = {
+        'n_tokens': 1024,
+        'd_model': 256,
+        'n_heads': 4,
+        'n_layers': 12,
+        'n_experts': 1,
     }
 
     """
@@ -204,6 +204,8 @@ class Base(nn.Module):
         n_layers: int = 12,
         p_dropout: float = 0.1,
 
+        n_experts: int=1,
+
         config = None,
     ):
         super().__init__()
@@ -214,6 +216,7 @@ class Base(nn.Module):
         self.d_model = d_model
         self.n_heads = n_heads
         self.n_layers = n_layers
+        self.n_experts = n_experts
 
         # +1 to include the stop token
         # to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
|
@ -272,6 +275,11 @@ class Base(nn.Module):
|
||||||
decoder_normalize_before=True,
|
decoder_normalize_before=True,
|
||||||
|
|
||||||
rotary_embedding_base=self.rotary_embedding_base, # 10000
|
rotary_embedding_base=self.rotary_embedding_base, # 10000
|
||||||
|
|
||||||
|
# MoE
|
||||||
|
use_xmoe=n_experts>1,
|
||||||
|
moe_freq=2,
|
||||||
|
moe_expert_count=n_experts,
|
||||||
))
|
))
|
||||||
|
|
||||||
self.classifier = nn.Linear(d_model, n_resp_tokens)
|
self.classifier = nn.Linear(d_model, n_resp_tokens)
|
||||||
|
|
|
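The three new kwargs above are torchscale's XMoE switches: use_xmoe routes the feed-forward blocks through XMoE expert FFNs instead of dense ones, moe_freq controls how often an MoE block appears (here every second layer), and moe_expert_count sets the experts per MoE layer. As a rough standalone sketch of the same flags outside this codebase (the RetNetConfig field names and sizes here are assumptions for illustration; check them against the installed torchscale version):

# Hedged sketch, not the repo's construction: a torchscale RetNet decoder
# config with XMoE enabled, mirroring the kwargs added in the hunk above.
from torchscale.architecture.config import RetNetConfig

n_experts = 8
config = RetNetConfig(
    decoder_embed_dim=1024,        # d_model (illustrative)
    decoder_retention_heads=16,    # n_heads (illustrative)
    decoder_layers=12,             # n_layers (illustrative)

    # MoE
    use_xmoe=n_experts > 1,        # route through XMoE expert FFNs
    moe_freq=2,                    # MoE block every 2nd layer
    moe_expert_count=n_experts,    # experts per MoE layer
)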
@@ -16,6 +16,7 @@ def get_free_port():
 
 _distributed_initialized = False
 def init_distributed( fn, *args, **kwargs ):
+    print("Initializing distributed...")
     fn(*args, **kwargs)
     _distributed_initialized = True
@@ -77,6 +77,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     # I'm very sure I can procedurally generate this list
     parser.add_argument("--text", type=str, default=kwargs["text"])
     parser.add_argument("--references", type=str, default=kwargs["reference"])
+    parser.add_argument("--language", type=str, default="en")
     parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
     parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
    parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*75))
@@ -104,6 +105,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     with timer() as t:
         wav, sr = tts.inference(
             text=args.text,
+            language=args.language,
             references=[args.references.split(";")],
             out_path=tmp.name,
             max_ar_steps=args.max_ar_steps,