Skip to content

Commit

Permalink
Improve: SQLite download process
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Apr 1, 2024
1 parent aff6293 commit 47f94ef
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 27 deletions.
6 changes: 5 additions & 1 deletion python/scripts/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import usearch


found_sqlite_path = usearch.sqlite
try:
found_sqlite_path = usearch.sqlite_path()
except FileNotFoundError:
found_sqlite_path = None

if found_sqlite_path is None:
pytest.skip(reason="Can't find an SQLite installation", allow_module_level=True)

Expand Down
57 changes: 32 additions & 25 deletions python/usearch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,23 @@ class BinaryManager:
def __init__(self, version: Optional[str] = None):
if version is None:
version = __version__
self.version = version

def sqlite_download_url(self) -> str:
"""
Constructs a download URL for the `usearch_sqlite` binary based on the operating system, architecture, and version.
Args:
version (str): The version of the binary to download.
self.version = version or __version__
self.download_dir = self.determine_download_dir()

@staticmethod
def determine_download_dir():
# Check if running within a virtual environment
virtual_env = os.getenv("VIRTUAL_ENV")
if virtual_env:
# Use a subdirectory within the virtual environment for binaries
return os.path.join(virtual_env, "bin", "usearch_binaries")
else:
# Fallback to a directory in the user's home folder
home_dir = os.path.expanduser("~")
return os.path.join(home_dir, ".usearch", "binaries")

Returns:
A string representing the download URL.
"""
def sqlite_file_name(self) -> str:
version = self.version
base_url = "https://github.com/unum-cloud/usearch/releases/download"
os_map = {"Linux": "linux", "Windows": "windows", "Darwin": "macos"}
arch_map = {
"x86_64": "amd64" if platform.system() != "Darwin" else "x86_64",
Expand All @@ -47,6 +50,12 @@ def sqlite_download_url(self) -> str:
arch_part = arch_map.get(arch, "")
extension = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}.get(platform.system(), "")
filename = f"usearch_sqlite_{os_part}_{arch_part}_{version}.{extension}"
return filename

def sqlite_download_url(self) -> str:
version = self.version
filename = self.sqlite_file_name()
base_url = "https://github.com/unum-cloud/usearch/releases/download"
url = f"{base_url}/v{version}/{filename}"
return url

Expand All @@ -66,7 +75,6 @@ def download_binary(self, url: str, dest_folder: str) -> str:
urllib.request.urlretrieve(url, dest_path)
return dest_path

@property
def sqlite_found_or_downloaded(self) -> Optional[str]:
"""
Attempts to locate the pre-installed `usearch_sqlite` binary.
Expand All @@ -89,20 +97,16 @@ def sqlite_found_or_downloaded(self) -> Optional[str]:
return os.path.join(root, file).removesuffix(file_extension)

# Check a temporary directory (assuming the binary might be downloaded from a GitHub release)
temp_dir = tempfile.gettempdir()
for root, _, files in os.walk(temp_dir):
for file in files:
if file.endswith(file_extension) and "usearch_sqlite" in file:
return os.path.join(root, file).removesuffix(file_extension)
local_path = os.path.join(self.download_dir, self.sqlite_file_name())
if os.path.exists(local_path):
return local_path.removesuffix(file_extension)

# If not found locally, warn the user and download from GitHub
temp_dir = tempfile.gettempdir()
warnings.warn("Will download `usearch_sqlite` binary from GitHub.", UserWarning)

# If the download fails due to HTTPError (e.g., 404 Not Found), like a missing lib version
try:
binary_path = self.download_binary(self.sqlite_download_url(), temp_dir)
binary_path = self.download_binary(self.sqlite_download_url(), self.download_dir)
except HTTPError as e:
# If the download fails due to HTTPError (e.g., 404 Not Found), like a missing lib version
if e.code == 404:
warnings.warn(f"Download failed: {e.url} could not be found.", UserWarning)
else:
Expand All @@ -117,6 +121,9 @@ def sqlite_found_or_downloaded(self) -> Optional[str]:
return None


# Use the function to set the `sqlite` computed property
binary_manager = BinaryManager()
sqlite = binary_manager.sqlite_found_or_downloaded
def sqlite_path(version: str = None) -> str:
manager = BinaryManager(version=version)
result = manager.sqlite_found_or_downloaded()
if result is None:
raise FileNotFoundError("Failed to find or download `usearch_sqlite` binary.")
return result
2 changes: 1 addition & 1 deletion sqlite/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import usearch

conn = sqlite3.connect(":memory:")
conn.enable_load_extension(True)
conn.load_extension(usearch.sqlite)
conn.load_extension(usearch.sqlite_path())
```

Afterwards, the following script should work fine.
Expand Down

0 comments on commit 47f94ef

Please sign in to comment.