This commit is contained in:
mrq 2024-06-05 20:53:10 -05:00
parent ff6fe6f1bc
commit 4073656293
2 changed files with 17 additions and 8 deletions

View File

@ -245,11 +245,10 @@ class Model:
elif isinstance(self.size, str) and self.size != "full":
name.append(self.size)
if self.arch_type != "transformer":
if self.experts > 1:
name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
else:
name.append(self.arch_type.replace("/", "-"))
if self.experts > 1:
name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
else:
name.append(self.arch_type.replace("/", "-"))
if cfg.optimizations.bitnet:
name.append("bitnet")
@ -264,11 +263,20 @@ class Model:
@property
def tokens(self):
    """Total token count for this model size.

    Returns ``size['tokens']`` when ``self.size`` is a dict that carries an
    explicit ``tokens`` entry; otherwise falls back to ``self.audio_tokens``.
    """
    # BUG FIX: the original tested hasattr(self.size, "tokens"), but
    # hasattr() checks *attributes*, not dict keys -- a plain dict never has
    # a "tokens" attribute, so the dict branch was dead code and the
    # property unconditionally returned self.audio_tokens. Key membership
    # is the intended check.
    if isinstance(self.size, dict) and "tokens" in self.size:
        return self.size['tokens']
    return self.audio_tokens
@property
def audio_tokens(self):
    """Audio codebook size (number of audio tokens).

    Returns ``size['audio_tokens']`` when ``self.size`` is a dict that
    carries an explicit ``audio_tokens`` entry; otherwise defaults to 1024.
    """
    # BUG FIX: hasattr() on a dict checks attributes, not keys, so the
    # original dict branch could never fire and the override in the config
    # dict was silently ignored (always 1024). Use key membership instead.
    if isinstance(self.size, dict) and "audio_tokens" in self.size:
        return self.size['audio_tokens']
    return 1024
@property
def text_tokens(self):
    """Text vocabulary size (number of text tokens).

    Returns ``size['text_tokens']`` when ``self.size`` is a dict that
    carries an explicit ``text_tokens`` entry; otherwise defaults to 256.
    """
    # BUG FIX: hasattr() on a dict checks attributes, not keys, so the
    # original dict branch could never fire and the override in the config
    # dict was silently ignored (always 256). Use key membership instead.
    if isinstance(self.size, dict) and "text_tokens" in self.size:
        return self.size['text_tokens']
    return 256
@property
def dim(self):
if isinstance(self.size, dict) and hasattr(self.size, "dim"):

View File

@ -639,7 +639,8 @@ class Base(nn.Module):
if text_list is not None:
inputs[i].append( ( "text", text_list[i] ) )
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
if self.rvq_level_emb is not None:
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
if proms_list is not None:
inputs[i].append( ( "prom", proms_list[i] ) )