/
extract_templates.py
43 lines (40 loc) · 1.58 KB
/
extract_templates.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import xml.etree.ElementTree as ET
import numpy as np
from matplotlib import pyplot as plt
def extract_templates(image_id):
"""
Extract templates
:param image_id: id of image to extract from
"""
image_dir = 'datasets/JPEGImages'
anno_dir = 'datasets/Annotations'
image_file = os.path.join(image_dir, '{}.jpg'.format(image_id))
anno_file = os.path.join(anno_dir, '{}.xml'.format(image_id))
assert os.path.exists(image_file), '{} not found.'.format(image_file)
assert os.path.exists(anno_file), '{} not found.'.format(anno_file)
anno_tree = ET.parse(anno_file)
objs = anno_tree.findall('object')
occurrences = {'waldo': 0, 'wenda': 0, 'wizard': 0}
image = np.asarray(plt.imread(image_file))
for key in occurrences.keys():
if not os.path.exists('templates/' + key + '/' + image_id):
os.makedirs('templates/' + key + '/' + image_id)
for idx, obj in enumerate(objs):
name = obj.find('name').text
bbox = obj.find('bndbox')
x1 = int(bbox.find('xmin').text)
y1 = int(bbox.find('ymin').text)
x2 = int(bbox.find('xmax').text)
y2 = int(bbox.find('ymax').text)
plt.imsave('templates/'
+ name
+ '/'
+ image_id
+ '/' + str(occurrences[name])
+ '.jpg',
image[y1:y2, x1:x2])
occurrences[name] += 1
for key in occurrences.keys():
if len(os.listdir('templates/' + key + '/' + image_id)) == 0:
os.rmdir('templates/' + key + '/' + image_id)