Skip to content

Commit

Permalink
Added support for parallel tile downloads and control of cache (#217)
Browse files Browse the repository at this point in the history
* Added support for parallel tile downloads in the bounds2raster and bounds2img functions

* Fixed a memory bug when using threads to download tiles in parallel while also caching the downloaded tiles. The solution was to use parallel processes instead of threads.

* Renamed num_parallel_tile_downloads to n_connections and set its default value to 1. Added different n_connections values when testing the bounds2img() function.

* Made max_connections a parameter (it was hardcoded before). Also added a parameter to disable caching, which is useful in resource-constrained environments when using parallel connections for downloads.

* Removed max_connections and updated docstrings for n_connections and disable_cache

* Changed disable_cache=False to use_cache=True in the bounds2raster() and bounds2img() function parameters to avoid a double negative
  • Loading branch information
JacobJeppesen committed Jun 12, 2023
1 parent 0c8c9ce commit 4ef4a67
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 25 deletions.
57 changes: 45 additions & 12 deletions contextily/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import rasterio as rio
from PIL import Image
from joblib import Memory as _Memory
from joblib import Parallel, delayed
from rasterio.transform import from_origin
from rasterio.io import MemoryFile
from rasterio.vrt import WarpedVRT
Expand Down Expand Up @@ -74,6 +75,8 @@ def bounds2raster(
ll=False,
wait=0,
max_retries=2,
n_connections=1,
use_cache=True,
):
"""
Take bounding box and zoom, and write tiles into a raster file in
Expand Down Expand Up @@ -113,6 +116,17 @@ def bounds2raster(
[Optional. Default: 2]
total number of rejected requests allowed before contextily
will stop trying to fetch more tiles from a rate-limited API.
n_connections: int
[Optional. Default: 1]
Number of connections for downloading tiles in parallel. Be careful not to overload the tile server and to check
the tile provider's terms of use before increasing this value. E.g., OpenStreetMap allows a maximum of 2
(https://operations.osmfoundation.org/policies/tiles/). If parallel downloads are allowed, a recommended
value for n_connections is 16; it should never be larger than 64.
use_cache: bool
[Optional. Default: True]
If False, caching of the downloaded tiles will be disabled. This can be useful in resource constrained
environments, especially when using n_connections > 1, or when a tile provider's terms of use don't allow
caching.
Returns
-------
Expand All @@ -126,7 +140,9 @@ def bounds2raster(
w, s = _sm2ll(w, s)
e, n = _sm2ll(e, n)
# Download
Z, ext = bounds2img(w, s, e, n, zoom=zoom, source=source, ll=True)
Z, ext = bounds2img(w, s, e, n, zoom=zoom, source=source, ll=True, n_connections=n_connections,
use_cache=use_cache)

# Write
# ---
h, w, b = Z.shape
Expand Down Expand Up @@ -155,7 +171,7 @@ def bounds2raster(


def bounds2img(
w, s, e, n, zoom="auto", source=None, ll=False, wait=0, max_retries=2
w, s, e, n, zoom="auto", source=None, ll=False, wait=0, max_retries=2, n_connections=1, use_cache=True
):
"""
Take bounding box and zoom and return an image with all the tiles
Expand Down Expand Up @@ -193,6 +209,17 @@ def bounds2img(
[Optional. Default: 2]
total number of rejected requests allowed before contextily
will stop trying to fetch more tiles from a rate-limited API.
n_connections: int
[Optional. Default: 1]
Number of connections for downloading tiles in parallel. Be careful not to overload the tile server and to check
the tile provider's terms of use before increasing this value. E.g., OpenStreetMap allows a maximum of 2
(https://operations.osmfoundation.org/policies/tiles/). If parallel downloads are allowed, a recommended
value for n_connections is 16; it should never be larger than 64.
use_cache: bool
[Optional. Default: True]
If False, caching of the downloaded tiles will be disabled. This can be useful in resource constrained
environments, especially when using n_connections > 1, or when a tile provider's terms of use don't allow
caching.
Returns
-------
Expand All @@ -213,15 +240,22 @@ def bounds2img(
if auto_zoom:
zoom = _calculate_zoom(w, s, e, n)
zoom = _validate_zoom(zoom, provider, auto=auto_zoom)
# download and merge tiles
tiles = []
arrays = []
for t in mt.tiles(w, s, e, n, [zoom]):
x, y, z = t.x, t.y, t.z
tile_url = provider.build_url(x=x, y=y, z=z)
image = _fetch_tile(tile_url, wait, max_retries)
tiles.append(t)
arrays.append(image)
# create list of tiles to download
tiles = list(mt.tiles(w, s, e, n, [zoom]))
tile_urls = [provider.build_url(x=tile.x, y=tile.y, z=tile.z) for tile in tiles]
# download tiles
if n_connections < 1 or not isinstance(n_connections, int):
raise ValueError(
f"n_connections must be a positive integer value."
)
# Use threads for a single connection to avoid the overhead of spawning a process. Use processes for multiple
# connections if caching is enabled, as threads lead to memory issues when used in combination with the joblib
# memory caching (used for the _fetch_tile() function).
preferred_backend = "threads" if (n_connections == 1 or not use_cache) else "processes"
fetch_tile_fn = memory.cache(_fetch_tile) if use_cache else _fetch_tile
arrays = Parallel(n_jobs=n_connections, prefer=preferred_backend)(
delayed(fetch_tile_fn)(tile_url, wait, max_retries) for tile_url in tile_urls)
# merge downloaded tiles
merged, extent = _merge_tiles(tiles, arrays)
# lon/lat extent --> Spheric Mercator
west, south, east, north = extent
Expand All @@ -247,7 +281,6 @@ def _process_source(source):
return provider


@memory.cache
def _fetch_tile(tile_url, wait, max_retries):
request = _retryer(tile_url, wait, max_retries)
with io.BytesIO(request.content) as image_stream:
Expand Down
31 changes: 18 additions & 13 deletions tests/test_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,31 @@ def test_bounds2raster():
assert_array_almost_equal(list(rtr.bounds), rtr_bounds)


@pytest.mark.parametrize("n_connections", [0, 1, 16])
@pytest.mark.network
def test_bounds2img():
def test_bounds2img(n_connections):
w, s, e, n = (
-106.6495132446289,
25.845197677612305,
-93.50721740722656,
36.49387741088867,
)
img, ext = ctx.bounds2img(w, s, e, n, zoom=4, ll=True)
solu = (
-12523442.714243276,
-10018754.171394622,
2504688.5428486555,
5009377.085697309,
)
for i, j in zip(ext, solu):
assert round(i - j, TOL) == 0
assert img[100, 100, :].tolist() == [230, 229, 188, 255]
assert img[100, 200, :].tolist() == [156, 180, 131, 255]
assert img[200, 100, :].tolist() == [230, 225, 189, 255]
if n_connections in [1, 16]: # valid number of connections (test single and multiple connections)
img, ext = ctx.bounds2img(w, s, e, n, zoom=4, ll=True, n_connections=n_connections)
solu = (
-12523442.714243276,
-10018754.171394622,
2504688.5428486555,
5009377.085697309,
)
for i, j in zip(ext, solu):
assert round(i - j, TOL) == 0
assert img[100, 100, :].tolist() == [230, 229, 188, 255]
assert img[100, 200, :].tolist() == [156, 180, 131, 255]
assert img[200, 100, :].tolist() == [230, 225, 189, 255]
elif n_connections == 0: # no connections should raise an error
with pytest.raises(ValueError):
img, ext = ctx.bounds2img(w, s, e, n, zoom=4, ll=True, n_connections=n_connections)


@pytest.mark.network
Expand Down

0 comments on commit 4ef4a67

Please sign in to comment.