fixed not being able to use other resnets as a base
This commit is contained in:
parent
6eb65326c3
commit
16ad4fa1c9
|
@ -35,6 +35,7 @@ def main():
|
||||||
parser.add_argument("--base64", type=str)
|
parser.add_argument("--base64", type=str)
|
||||||
parser.add_argument("--write", type=Path)
|
parser.add_argument("--write", type=Path)
|
||||||
parser.add_argument("--temp", type=float, default=1.0)
|
parser.add_argument("--temp", type=float, default=1.0)
|
||||||
|
parser.add_argument("--limit", type=int, default=0)
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
images = []
|
images = []
|
||||||
|
@ -43,9 +44,13 @@ def main():
|
||||||
for p in args.path.rglob("./*.jpg"):
|
for p in args.path.rglob("./*.jpg"):
|
||||||
image = Image.open(p).convert('RGB')
|
image = Image.open(p).convert('RGB')
|
||||||
images.append(image)
|
images.append(image)
|
||||||
|
if args.limit and len(images) >= args.limit:
|
||||||
|
break
|
||||||
for p in args.path.rglob("./*.png"):
|
for p in args.path.rglob("./*.png"):
|
||||||
image = Image.open(p).convert('RGB')
|
image = Image.open(p).convert('RGB')
|
||||||
images.append(image)
|
images.append(image)
|
||||||
|
if args.limit and len(images) >= args.limit:
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
image = Image.open(args.path).convert('RGB')
|
image = Image.open(args.path).convert('RGB')
|
||||||
images.append(image)
|
images.append(image)
|
||||||
|
|
|
@ -158,13 +158,13 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
path = self.paths[index]
|
path = self.paths[index]
|
||||||
tokens = tokenize( path.stem.upper() )
|
text = path.stem.upper()
|
||||||
text = torch.tensor( tokens ).to(dtype=torch.uint8)
|
|
||||||
|
|
||||||
image = Image.open(path).convert('RGB')
|
image = Image.open(path).convert('RGB')
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
|
|
||||||
image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB
|
text = torch.tensor( tokenize( text ) ).to(dtype=torch.uint8)
|
||||||
|
image = self.transform(image).to(dtype=self.image_dtype)
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
index=index,
|
index=index,
|
||||||
|
@ -215,36 +215,24 @@ def _create_dataloader(dataset, training):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_train_val_paths( val_ratio=0.1 ):
|
def _load_train_val_paths( val_ratio=0.1 ):
|
||||||
paths = []
|
|
||||||
train_paths = []
|
train_paths = []
|
||||||
val_paths = []
|
val_paths = []
|
||||||
|
|
||||||
for data_dir in cfg.dataset.training:
|
for data_dir in cfg.dataset.training:
|
||||||
paths.extend(data_dir.rglob("*.jpg"))
|
train_paths.extend(data_dir.rglob("*.jpg"))
|
||||||
paths.extend(data_dir.rglob("*.png"))
|
train_paths.extend(data_dir.rglob("*.png"))
|
||||||
|
|
||||||
if len(paths) > 0:
|
if len(train_paths) > 0:
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
random.shuffle(paths)
|
random.shuffle(train_paths)
|
||||||
train_paths.extend(paths)
|
|
||||||
|
|
||||||
if len(cfg.dataset.validation) == 0:
|
for data_dir in cfg.dataset.validation:
|
||||||
val_len = math.floor(len(train_paths) * val_ratio)
|
val_paths.extend(data_dir.rglob("*.jpg"))
|
||||||
train_len = math.floor(len(train_paths) * (1 - val_ratio))
|
val_paths.extend(data_dir.rglob("*.png"))
|
||||||
|
|
||||||
val_paths = train_paths[:-val_len]
|
if len(val_paths) > 0:
|
||||||
train_paths = train_paths[:train_len]
|
random.seed(0)
|
||||||
else:
|
random.shuffle(val_paths)
|
||||||
paths = []
|
|
||||||
|
|
||||||
for data_dir in cfg.dataset.validation:
|
|
||||||
paths.extend(data_dir.rglob("*.jpg"))
|
|
||||||
paths.extend(data_dir.rglob("*.png"))
|
|
||||||
|
|
||||||
if len(paths) > 0:
|
|
||||||
random.seed(0)
|
|
||||||
random.shuffle(paths)
|
|
||||||
val_paths.extend(paths)
|
|
||||||
|
|
||||||
train_paths, val_paths = map(sorted, [train_paths, val_paths])
|
train_paths, val_paths = map(sorted, [train_paths, val_paths])
|
||||||
|
|
||||||
|
@ -596,3 +584,6 @@ if __name__ == "__main__":
|
||||||
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
|
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
|
||||||
break
|
break
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
...
|
|
@ -35,7 +35,7 @@ class Model(nn.Module):
|
||||||
|
|
||||||
self.n_tokens = n_tokens
|
self.n_tokens = n_tokens
|
||||||
self.n_len = n_len + 2 # start/stop tokens
|
self.n_len = n_len + 2 # start/stop tokens
|
||||||
self.d_model = d_model
|
# self.d_model = d_model
|
||||||
self.d_resnet = d_resnet
|
self.d_resnet = d_resnet
|
||||||
|
|
||||||
ResNet = resnet18
|
ResNet = resnet18
|
||||||
|
@ -56,7 +56,7 @@ class Model(nn.Module):
|
||||||
ResNet = resnet152
|
ResNet = resnet152
|
||||||
|
|
||||||
self.resnet = ResNet(pretrained=False)
|
self.resnet = ResNet(pretrained=False)
|
||||||
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
|
self.resnet.fc = nn.Linear( self.resnet.fc.in_features, self.n_tokens * self.n_len )
|
||||||
|
|
||||||
self.accuracy_metric = MulticlassAccuracy(
|
self.accuracy_metric = MulticlassAccuracy(
|
||||||
n_tokens,
|
n_tokens,
|
||||||
|
@ -96,10 +96,13 @@ class Model(nn.Module):
|
||||||
nll = sum( loss ) / len( loss ),
|
nll = sum( loss ) / len( loss ),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.stats = dict(
|
try:
|
||||||
acc = self.accuracy_metric( pred, labels ),
|
self.stats = dict(
|
||||||
precision = self.precision_metric( pred, labels ),
|
acc = self.accuracy_metric( pred, labels ),
|
||||||
)
|
precision = self.precision_metric( pred, labels ),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]
|
answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user