oops
This commit is contained in:
parent
ff6fe6f1bc
commit
4073656293
|
@ -245,11 +245,10 @@ class Model:
|
|||
elif isinstance(self.size, str) and self.size != "full":
|
||||
name.append(self.size)
|
||||
|
||||
if self.arch_type != "transformer":
|
||||
if self.experts > 1:
|
||||
name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
|
||||
else:
|
||||
name.append(self.arch_type.replace("/", "-"))
|
||||
if self.experts > 1:
|
||||
name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
|
||||
else:
|
||||
name.append(self.arch_type.replace("/", "-"))
|
||||
|
||||
if cfg.optimizations.bitnet:
|
||||
name.append("bitnet")
|
||||
|
@ -264,11 +263,20 @@ class Model:
|
|||
|
||||
@property
|
||||
def tokens(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "tokens"):
|
||||
return self.size['tokens']
|
||||
return self.audio_tokens
|
||||
|
||||
@property
|
||||
def audio_tokens(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "audio_tokens"):
|
||||
return self.size['audio_tokens']
|
||||
return 1024
|
||||
|
||||
@property
|
||||
def text_tokens(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "text_tokens"):
|
||||
return self.size['text_tokens']
|
||||
return 256
|
||||
|
||||
@property
|
||||
def dim(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "dim"):
|
||||
|
|
|
@ -639,7 +639,8 @@ class Base(nn.Module):
|
|||
if text_list is not None:
|
||||
inputs[i].append( ( "text", text_list[i] ) )
|
||||
|
||||
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
|
||||
if self.rvq_level_emb is not None:
|
||||
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
|
||||
|
||||
if proms_list is not None:
|
||||
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||
|
|
Loading…
Reference in New Issue
Block a user