This commit is contained in:
mrq 2024-06-05 20:53:10 -05:00
parent ff6fe6f1bc
commit 4073656293
2 changed files with 17 additions and 8 deletions

View File

@ -245,11 +245,10 @@ class Model:
elif isinstance(self.size, str) and self.size != "full": elif isinstance(self.size, str) and self.size != "full":
name.append(self.size) name.append(self.size)
if self.arch_type != "transformer": if self.experts > 1:
if self.experts > 1: name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
name.append(f'{self.experts}x'+self.arch_type.replace("/", "-")) else:
else: name.append(self.arch_type.replace("/", "-"))
name.append(self.arch_type.replace("/", "-"))
if cfg.optimizations.bitnet: if cfg.optimizations.bitnet:
name.append("bitnet") name.append("bitnet")
@ -264,11 +263,20 @@ class Model:
@property @property
def tokens(self): def tokens(self):
if isinstance(self.size, dict) and hasattr(self.size, "tokens"): return self.audio_tokens
return self.size['tokens']
@property
def audio_tokens(self):
if isinstance(self.size, dict) and hasattr(self.size, "audio_tokens"):
return self.size['audio_tokens']
return 1024 return 1024
@property
def text_tokens(self):
if isinstance(self.size, dict) and hasattr(self.size, "text_tokens"):
return self.size['text_tokens']
return 256
@property @property
def dim(self): def dim(self):
if isinstance(self.size, dict) and hasattr(self.size, "dim"): if isinstance(self.size, dict) and hasattr(self.size, "dim"):

View File

@ -639,7 +639,8 @@ class Base(nn.Module):
if text_list is not None: if text_list is not None:
inputs[i].append( ( "text", text_list[i] ) ) inputs[i].append( ( "text", text_list[i] ) )
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) ) if self.rvq_level_emb is not None:
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
if proms_list is not None: if proms_list is not None:
inputs[i].append( ( "prom", proms_list[i] ) ) inputs[i].append( ( "prom", proms_list[i] ) )