fixed not being able to use other resnets as a base

This commit is contained in:
mrq 2024-09-22 22:07:19 -05:00
parent 6eb65326c3
commit 16ad4fa1c9
3 changed files with 33 additions and 34 deletions

View File

@ -35,6 +35,7 @@ def main():
parser.add_argument("--base64", type=str)
parser.add_argument("--write", type=Path)
parser.add_argument("--temp", type=float, default=1.0)
parser.add_argument("--limit", type=int, default=0)
args, unknown = parser.parse_known_args()
images = []
@ -43,9 +44,13 @@ def main():
for p in args.path.rglob("./*.jpg"):
image = Image.open(p).convert('RGB')
images.append(image)
if args.limit and len(images) >= args.limit:
break
for p in args.path.rglob("./*.png"):
image = Image.open(p).convert('RGB')
images.append(image)
if args.limit and len(images) >= args.limit:
break
else:
image = Image.open(args.path).convert('RGB')
images.append(image)

View File

@ -158,13 +158,13 @@ class Dataset(_Dataset):
def __getitem__(self, index):
path = self.paths[index]
tokens = tokenize( path.stem.upper() )
text = torch.tensor( tokens ).to(dtype=torch.uint8)
text = path.stem.upper()
image = Image.open(path).convert('RGB')
width, height = image.size
image = self.transform(image).to(dtype=self.image_dtype) # resnet has to be RGB
text = torch.tensor( tokenize( text ) ).to(dtype=torch.uint8)
image = self.transform(image).to(dtype=self.image_dtype)
return dict(
index=index,
@ -215,36 +215,24 @@ def _create_dataloader(dataset, training):
)
def _load_train_val_paths( val_ratio=0.1 ):
paths = []
train_paths = []
val_paths = []
for data_dir in cfg.dataset.training:
paths.extend(data_dir.rglob("*.jpg"))
paths.extend(data_dir.rglob("*.png"))
train_paths.extend(data_dir.rglob("*.jpg"))
train_paths.extend(data_dir.rglob("*.png"))
if len(paths) > 0:
if len(train_paths) > 0:
random.seed(0)
random.shuffle(paths)
train_paths.extend(paths)
random.shuffle(train_paths)
if len(cfg.dataset.validation) == 0:
val_len = math.floor(len(train_paths) * val_ratio)
train_len = math.floor(len(train_paths) * (1 - val_ratio))
for data_dir in cfg.dataset.validation:
val_paths.extend(data_dir.rglob("*.jpg"))
val_paths.extend(data_dir.rglob("*.png"))
val_paths = train_paths[:-val_len]
train_paths = train_paths[:train_len]
else:
paths = []
for data_dir in cfg.dataset.validation:
paths.extend(data_dir.rglob("*.jpg"))
paths.extend(data_dir.rglob("*.png"))
if len(paths) > 0:
random.seed(0)
random.shuffle(paths)
val_paths.extend(paths)
if len(val_paths) > 0:
random.seed(0)
random.shuffle(val_paths)
train_paths, val_paths = map(sorted, [train_paths, val_paths])
@ -596,3 +584,6 @@ if __name__ == "__main__":
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )
break
"""
if __name__ == "__main__":
...

View File

@ -35,7 +35,7 @@ class Model(nn.Module):
self.n_tokens = n_tokens
self.n_len = n_len + 2 # start/stop tokens
self.d_model = d_model
# self.d_model = d_model
self.d_resnet = d_resnet
ResNet = resnet18
@ -56,7 +56,7 @@ class Model(nn.Module):
ResNet = resnet152
self.resnet = ResNet(pretrained=False)
self.resnet.fc = nn.Linear( self.d_model, self.n_tokens * self.n_len )
self.resnet.fc = nn.Linear( self.resnet.fc.in_features, self.n_tokens * self.n_len )
self.accuracy_metric = MulticlassAccuracy(
n_tokens,
@ -96,10 +96,13 @@ class Model(nn.Module):
nll = sum( loss ) / len( loss ),
)
self.stats = dict(
acc = self.accuracy_metric( pred, labels ),
precision = self.precision_metric( pred, labels ),
)
try:
self.stats = dict(
acc = self.accuracy_metric( pred, labels ),
precision = self.precision_metric( pred, labels ),
)
except Exception as e:
pass
answer = [ "".join([ self.symmap[f'{x.item()}'] for x in t ]) for t in pred ]