vall-e/vall_e/vall_e/__init__.py
2023-01-12 19:55:04 +08:00

41 lines
849 B
Python

from ..config import cfg
from .ar import AR
from .nar import NAR
def get_model(name):
name = name.lower()
if name.startswith("ar"):
Model = AR
elif name.startswith("nar"):
Model = NAR
else:
raise ValueError("Model name should start with AR or NAR.")
if "-quarter" in name:
model = Model(
cfg.num_tokens,
d_model=256,
n_heads=4,
n_layers=12,
)
elif "-half" in name:
model = Model(
cfg.num_tokens,
d_model=512,
n_heads=8,
n_layers=12,
)
elif "-official" in name:
model = Model(
cfg.num_tokens,
d_model=1024,
n_heads=16,
n_layers=12,
)
else:
raise NotImplementedError(name)
return model