fixed not being able to use other resnets as a base
This commit is contained in:
parent
6eb65326c3
commit
16ad4fa1c9
|
@ -35,6 +35,7 @@ def main():
|
|||
parser.add_argument("--base64", type=str)
|
||||
parser.add_argument("--write", type=Path)
|
||||
parser.add_argument("--temp", type=float, default=1.0)
|
||||
parser.add_argument("--limit", type=int, default=0)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
images = []
|
||||
|
@ -43,9 +44,13 @@ def main():
|
|||
for p in args.path.rglob("./*.jpg"):
|
||||
image = Image.open(p).convert('RGB')
|
||||
images.append(image)
|
||||
if args.limit and len(images) >= args.limit:
|
||||
break
|
||||
for p in args.path.rglob("./*.png"):
|
||||
image = Image.open(p).convert('RGB')
|
||||
images.append(image)
|
||||
if args.limit and len(images) >= args.limit:
|
||||
break
|
||||
else:
|
||||
image = Image.open(args.path).convert('RGB')
|
||||
images.append(image)
|
||||
|
|
|
@ -158,13 +158,13 @@ class Dataset(_Dataset):
|
|||
|
||||
def __getitem__(self, index):
|
||||
path = self.paths[index]
|
||||
tokens = tokenize( path.stem.upper() )
|
||||
text = torch.tensor( tokens ).to(dtype=torch.uint8)
|
||||
text = path.stem.upper()
|
||||
|
||||
image = Image.open(path).convert('RGB')
|
||||
width, height = image.size
|
||||
|
||||
image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB
|
||||
|
||||
text = torch.tensor( tokenize( text ) ).to(dtype=torch.uint8)
|
||||
image = self.transform(image).to(dtype=self.image_dtype)
|
||||
|
||||
return dict(
|
||||
index=index,
|
||||
|
@ -215,36 +215,24 @@ def _create_dataloader(dataset, training):
|
|||
)
|
||||
|
||||
def _load_train_val_paths( val_ratio=0.1 ):
|
||||
paths = []
|
||||
train_paths = []
|
||||
val_paths = []
|
||||
|
||||
for data_dir in cfg.dataset.training:
|
||||
paths.extend(data_dir.rglob("*.jpg"))
|
||||
paths.extend(data_dir.rglob("*.png"))
|
||||
train_paths.extend(data_dir.rglob("*.jpg"))
|
||||
train_paths.extend(data_dir.rglob("*.png"))
|
||||
|
||||
if len(paths) > 0:
|
||||
if len(train_paths) > 0:
|
||||
random.seed(0)
|
||||
random.shuffle(paths)
|
||||
train_paths.extend(paths)
|
||||
random.shuffle(train_paths)
|
||||
|
||||
if len(cfg.dataset.validation) == 0:
|
||||
val_len = math.floor(len(train_paths) * val_ratio)
|
||||
train_len = math.floor(len(train_paths) * (1 - val_ratio))
|
||||
for data_dir in cfg.dataset.validation:
|
||||
val_paths.extend(data_dir.rglob("*.jpg"))
|
||||
val_paths.extend(data_dir.rglob("*.png"))
|
||||
|
||||
val_paths = train_paths[:-val_len]
|
||||
train_paths = train_paths[:train_len]
|
||||
else:
|
||||
paths = []
|
||||
|
||||
for data_dir in cfg.dataset.validation:
|
||||
paths.extend(data_dir.rglob("*.jpg"))
|
||||
paths.extend(data_dir.rglob("*.png"))
|
||||
|
||||
if len(paths) > 0:
|
||||
random.seed(0)
|
||||
random.shuffle(paths)
|
||||
val_paths.extend(paths)
|
||||
if len(val_paths) > 0:
|
||||
random.seed(0)
|
||||
random.shuffle(val_paths)
|
||||
|
||||
train_paths, val_paths = map(sorted, [train_paths, val_paths])
|
||||
|
||||
|
@ -595,4 +583,7 @@ if __name__ == "__main__":
|
|||
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
|
||||
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
|
||||
break
|
||||
"""
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
...
|
|
@ -35,7 +35,7 @@ class Model(nn.Module):
|
|||
|
||||
self.n_tokens = n_tokens
|
||||
self.n_len = n_len + 2 # start/stop tokens
|
||||
self.d_model = d_model
|
||||
# self.d_model = d_model
|
||||
self.d_resnet = d_resnet
|
||||
|
||||
ResNet = resnet18
|
||||
|
@ -56,7 +56,7 @@ class Model(nn.Module):
|
|||
ResNet = resnet152
|
||||
|
||||
self.resnet = ResNet(pretrained=False)
|
||||
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
|
||||
self.resnet.fc = nn.Linear( self.resnet.fc.in_features, self.n_tokens * self.n_len )
|
||||
|
||||
self.accuracy_metric = MulticlassAccuracy(
|
||||
n_tokens,
|
||||
|
@ -96,12 +96,15 @@ class Model(nn.Module):
|
|||
nll = sum( loss ) / len( loss ),
|
||||
)
|
||||
|
||||
self.stats = dict(
|
||||
acc = self.accuracy_metric( pred, labels ),
|
||||
precision = self.precision_metric( pred, labels ),
|
||||
)
|
||||
try:
|
||||
self.stats = dict(
|
||||
acc = self.accuracy_metric( pred, labels ),
|
||||
precision = self.precision_metric( pred, labels ),
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]
|
||||
|
||||
return answer
|
||||
return answer
|
||||
|
|
Loading…
Reference in New Issue
Block a user