oops
This commit is contained in:
parent
ff6fe6f1bc
commit
4073656293
|
@ -245,11 +245,10 @@ class Model:
|
||||||
elif isinstance(self.size, str) and self.size != "full":
|
elif isinstance(self.size, str) and self.size != "full":
|
||||||
name.append(self.size)
|
name.append(self.size)
|
||||||
|
|
||||||
if self.arch_type != "transformer":
|
if self.experts > 1:
|
||||||
if self.experts > 1:
|
name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
|
||||||
name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
|
else:
|
||||||
else:
|
name.append(self.arch_type.replace("/", "-"))
|
||||||
name.append(self.arch_type.replace("/", "-"))
|
|
||||||
|
|
||||||
if cfg.optimizations.bitnet:
|
if cfg.optimizations.bitnet:
|
||||||
name.append("bitnet")
|
name.append("bitnet")
|
||||||
|
@ -264,11 +263,20 @@ class Model:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tokens(self):
|
def tokens(self):
|
||||||
if isinstance(self.size, dict) and hasattr(self.size, "tokens"):
|
return self.audio_tokens
|
||||||
return self.size['tokens']
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_tokens(self):
|
||||||
|
if isinstance(self.size, dict) and hasattr(self.size, "audio_tokens"):
|
||||||
|
return self.size['audio_tokens']
|
||||||
return 1024
|
return 1024
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text_tokens(self):
|
||||||
|
if isinstance(self.size, dict) and hasattr(self.size, "text_tokens"):
|
||||||
|
return self.size['text_tokens']
|
||||||
|
return 256
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dim(self):
|
def dim(self):
|
||||||
if isinstance(self.size, dict) and hasattr(self.size, "dim"):
|
if isinstance(self.size, dict) and hasattr(self.size, "dim"):
|
||||||
|
|
|
@ -639,7 +639,8 @@ class Base(nn.Module):
|
||||||
if text_list is not None:
|
if text_list is not None:
|
||||||
inputs[i].append( ( "text", text_list[i] ) )
|
inputs[i].append( ( "text", text_list[i] ) )
|
||||||
|
|
||||||
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
|
if self.rvq_level_emb is not None:
|
||||||
|
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
|
||||||
|
|
||||||
if proms_list is not None:
|
if proms_list is not None:
|
||||||
inputs[i].append( ( "prom", proms_list[i] ) )
|
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||||
|
|
Loading…
Reference in New Issue
Block a user