Skip to content
This repository has been archived by the owner on Jun 21, 2022. It is now read-only.

Commit

Permalink
Added last DiskCache feature: __contains__
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Oct 18, 2017
1 parent 881fb7a commit 1ecc21b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 7 deletions.
15 changes: 15 additions & 0 deletions tests/test_diskcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,22 +153,37 @@ def test_eviction(self):
cache["a"] = numpy.ones(25, dtype=numpy.float64)
self.assertEqual(cache.numbytes, 80360)
self.assertEqual(list(cache.keys()), ["a"])
self.assertTrue("a" in cache)

cache["b"] = numpy.ones(25, dtype=numpy.float64)
self.assertEqual(cache.numbytes, 80640)
self.assertEqual(list(cache.keys()), ["a", "b"])
self.assertTrue("a" in cache)
self.assertTrue("b" in cache)

cache["c"] = numpy.ones(25, dtype=numpy.float64)
self.assertEqual(cache.numbytes, 80920)
self.assertEqual(list(cache.keys()), ["a", "b", "c"])
self.assertTrue("a" in cache)
self.assertTrue("b" in cache)
self.assertTrue("c" in cache)

cache["d"] = numpy.ones(25, dtype=numpy.float64)
self.assertEqual(cache.numbytes, 81200)
self.assertEqual(list(cache.keys()), ["a", "b", "c", "d"])
self.assertTrue("a" in cache)
self.assertTrue("b" in cache)
self.assertTrue("c" in cache)
self.assertTrue("d" in cache)

cache["e"] = numpy.ones(25, dtype=numpy.float64)
self.assertEqual(cache.numbytes, 81200)
self.assertEqual(list(cache.keys()), ["b", "c", "d", "e"])
self.assertTrue("a" not in cache)
self.assertTrue("b" in cache)
self.assertTrue("c" in cache)
self.assertTrue("d" in cache)
self.assertTrue("e" in cache)

cache["f"] = numpy.ones(25, dtype=numpy.float64)
self.assertEqual(cache.numbytes, 81200)
Expand Down
29 changes: 23 additions & 6 deletions uproot/cache/diskcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,22 @@ def destroy(self):
def refresh_config(self):
self.config.__dict__.update(json.load(os.path.join(self.directory, self.CONFIG_FILE)))

def __contains__(self, name):
if not isinstance(name, bytes) and hasattr(name, "encode"):
name = name.encode("utf-8")
if not isinstance(name, bytes):
raise TypeError("keys must be strings, not {0}".format(type(name)))

self._lockstate()
try:
for num, n in self._walkorder(os.path.join(self.directory, self.ORDER_DIR), reverse=True):
if name == n:
return True
return False

finally:
self._unlockstate()

def promote(self, name):
if not isinstance(name, bytes) and hasattr(name, "encode"):
name = name.encode("utf-8")
Expand Down Expand Up @@ -442,7 +458,7 @@ def keys(self):
self._lockstate()
try:
for num, name in self._walkorder(os.path.join(self.directory, self.ORDER_DIR)):
yield name
yield name.decode("utf-8")
finally:
self._unlockstate()

Expand Down Expand Up @@ -471,7 +487,7 @@ def cleanup():
return cleanup

for name, linkpath in linkpaths:
yield name, self.read(linkpath, make_cleanup(linkpath))
yield name.decode("utf-8"), self.read(linkpath, make_cleanup(linkpath))

def values(self):
for name, obj in self.items():
Expand Down Expand Up @@ -661,21 +677,22 @@ def _newpath(self, name):
# return new path
return os.path.join(path, self._formatter.format(num) + self.config.delimiter + urlquote(name, safe=""))

def _walkorder(self, path):
def _walkorder(self, path, sort=True, reverse=False):
assert self._lock is not None
items = os.listdir(path)
items.sort()
if sort:
items.sort(reverse=reverse)

for fn in items:
subpath = os.path.join(path, fn)
if os.path.isdir(subpath):
for x in self._walkorder(subpath):
for x in self._walkorder(subpath, sort=sort, reverse=reverse):
yield x
else:
i = fn.index(self.config.delimiter)
num = int(fn[:i])
name = urlunquote(fn[i + 1:])
yield num, name
yield num, name.encode("utf-8")

def _evict(self, path, top):
assert self._lock is not None
Expand Down
2 changes: 1 addition & 1 deletion uproot/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import re

__version__ = "1.6.0"
__version__ = "1.6.1"
version = __version__
version_info = tuple(re.split(r"[-\.]", __version__))

Expand Down

0 comments on commit 1ecc21b

Please sign in to comment.