# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import json
from collections import defaultdict
from pathlib import Path
from typing import Callable, Iterator

import tensorflow_datasets as tfds
import torch
from PIL import Image
from torch.utils.data import Dataset, IterDataPipe
from torch.utils.data.datapipes.iter import ShardingFilter
from torchvision.datasets import ImageFolder


class CocoCaptions(Dataset):
"""
COCO captions dataset. Homepage: https://cocodataset.org
"""
def __init__(self, root: str | Path, split: str, transform: Callable | None = None):
"""
Args:
            root: Dataset root directory. It should contain image directories
                named `train2017` and `val2017`, and an `annotations` directory
                containing the caption annotations JSON files.
split: Name of 2017 split to load, one of `{train, val}`.
            transform: A function/transform that takes in a PIL image and
                returns a transformed version.
"""
super().__init__()
self.root = Path(root)
self.split = split
self.transform = transform
        # Read annotations for the given split.
        json_path = self.root / "annotations" / f"captions_{split}2017.json"
        with open(json_path) as f:
            coco_json = json.load(f)
# Build a temporary mapping between image ID and captions.
image_id_to_anns = defaultdict(list)
for ann in coco_json["annotations"]:
image_id_to_anns[ann["image_id"]].append(ann)
        # Convert the above mapping to a list of tuples formatted as:
        # `(image_id, image_path, list[caption_ids], list[caption])`.
self.samples = [
(
image_id,
self.root / f"{split}2017" / f"{image_id:0>12d}.jpg",
[ann["id"] for ann in anns],
[ann["caption"] for ann in anns],
)
for image_id, anns in image_id_to_anns.items()
]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict:
image_id, image_path, caption_ids, captions = self.samples[idx]
image = Image.open(image_path).convert("RGB")
if self.transform is not None:
image = self.transform(image)
return {
"image_id": image_id,
"caption_ids": caption_ids,
"image": image,
"captions": captions,
}
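

# A minimal usage sketch (added for illustration; `/datasets/coco` is a
# placeholder path, and `split="val"` assumes COCO val2017 is present):
#
#   dset = CocoCaptions(root="/datasets/coco", split="val")
#   sample = dset[0]
#   sample["image"]     # PIL image (or transformed output)
#   sample["captions"]  # list of reference captions for this image
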
class Flickr30kCaptions(CocoCaptions):
"""
Flickr30K captions dataset.
Karpathy split JSON can be downloaded from this webpage:
https://cs.stanford.edu/people/karpathy/deepimagesent/
"""

    def __init__(self, root: str | Path, split: str, transform: Callable | None = None):
"""
Args:
root: Dataset root directory. It should contain a JSON file named
`dataset_flickr30k.json` containing Karpathy splits, and a
directory named `flickr30k_images` with all images (~31K).
split: Name of split to load, one of `{train, val, test}`.
            transform: A function/transform that takes in a PIL image and
                returns a transformed version.
"""
self.root = Path(root)
self.split = split
self.transform = transform
        # Read annotations and keep only those belonging to the specified split.
        with open(self.root / "dataset_flickr30k.json") as f:
            flickr_json = json.load(f)
        # Convert annotations to a list of tuples formatted as:
        # `(image_id, image_path, list[caption_ids], list[caption])`.
        # Only keep images that belong to the required split.
self.samples = [
(
int(ann["filename"][:-4]),
self.root / "flickr30k_images" / ann["filename"],
ann["sentids"],
[entry["raw"] for entry in ann["sentences"]],
)
for ann in flickr_json["images"]
if ann["split"] == split
]
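

# A usage sketch (illustrative; `/datasets/flickr30k` is a placeholder path).
# Image loading and transforms are inherited from `CocoCaptions.__getitem__`:
#
#   dset = Flickr30kCaptions(root="/datasets/flickr30k", split="test")
#   sample = dset[0]  # same keys as CocoCaptions samples
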
class ImageNet(ImageFolder):
"""
    Lightweight wrapper over Torchvision `ImageFolder` to load the ImageNet
    dataset.
"""

    def __init__(self, root: str, split: str = "train", **kwargs):
super().__init__(str(Path(root) / split), **kwargs)
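

# A usage sketch (illustrative; the root path is a placeholder). As with any
# `ImageFolder`, each split directory must hold one subdirectory per class:
#
#   dset = ImageNet(root="/datasets/imagenet", split="val")
#   # expects images under `/datasets/imagenet/val/<class_name>/...`
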
class TfdsWrapper(IterDataPipe):
"""
Minimal wrapper on `tensorflow-datasets` to serve `(image, label)`
tuples for image classification datasets. This wrapper enables a consistent
output format with dataset implementations from the Torchvision library.
"""
def __init__(
self,
name: str,
root: str | Path,
split: str,
transform: Callable | None = None,
):
"""
Args:
            name: Name of a dataset supported by TensorFlow Datasets. See
                https://www.tensorflow.org/datasets/catalog/overview for details.
root: Dataset root directory. This is passed to the `data_dir`
argument of `tfds.load`. All datasets are auto-downloaded and
cached in this directory.
split: Which dataset split to load. This should be one of the official
splits for the given dataset.
            transform: A function/transform that takes in a PIL image and
                returns a transformed version.
"""
super().__init__()
self.name = name
self.split = split
self.transform = transform
dset = tfds.load(name, split=split, data_dir=root)
dset = tfds.as_numpy(dset)
# Record length of the dataset before further wrapping.
self._length = len(dset)
# Wrap the tensorflow dataset with `IterDataPipe` and apply sharding filter
# to avoid duplicates when multiple CPU workers are used in DataLoader.
self.dset = ShardingFilter(dset)

    def __repr__(self):
        return f"TfdsWrapper(name={self.name}, split={self.split})"

    def __len__(self):
        return self._length

    def __iter__(self) -> Iterator[tuple[Image.Image, torch.Tensor]]:
for instance in self.dset:
            # Convert numpy arrays to a PIL image and a label tensor.
            # Handle the special case of single-channel MNIST images.
if self.name == "mnist":
image = Image.fromarray(instance["image"][..., 0], mode="L")
else:
image = Image.fromarray(instance["image"])
image = image.convert("RGB")
label = torch.tensor(instance["label"])
if self.transform is not None:
image = self.transform(image)
yield image, label
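

# A usage sketch (illustrative; dataset name and root path are placeholders).
# Because of the sharding filter above, each DataLoader worker iterates over a
# disjoint shard, so multi-worker loading yields no duplicate samples:
#
#   from torch.utils.data import DataLoader
#   from torchvision import transforms
#
#   dset = TfdsWrapper(
#       "cifar10", root="/datasets/tfds", split="test",
#       transform=transforms.ToTensor(),  # tensors, so default collate works
#   )
#   loader = DataLoader(dset, batch_size=32, num_workers=4)
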
class CLEVRCounts(TfdsWrapper):
"""
CLEVR-Counts image classification dataset. Counting the number of objects in
a scene is framed as a classification task. This task was included in the
    Visual Task Adaptation Benchmark (VTAB) and is used in the CLIP evaluation suite.
"""

    def __init__(self, root: str | Path, split: str, transform: Callable | None = None):
super().__init__("clevr", root, split, transform)
        # Possible object counts; a count is converted to a contiguous label
        # via its index in this list.
        self._labels = [10, 3, 4, 5, 6, 7, 8, 9]

    def __iter__(self) -> Iterator[tuple[Image.Image, torch.Tensor]]:
for instance in self.dset:
image = Image.fromarray(instance["image"]).convert("RGB")
num_objects = len(instance["objects"]["color"])
label = torch.tensor(self._labels.index(num_objects))
if self.transform is not None:
image = self.transform(image)
yield image, label
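

# For example, a scene with 10 objects yields label 0 and a scene with 3
# objects yields label 1, following the `_labels` order above. A usage sketch
# (split names follow the TFDS `clevr` catalog entry; paths are placeholders):
#
#   dset = CLEVRCounts(root="/datasets/tfds", split="validation")
#   image, label = next(iter(dset))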