-
Notifications
You must be signed in to change notification settings - Fork 12
Restrict Vault token exchange to specific hosts; improve auth errors; (Issue #19) #40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,6 +1,7 @@ | ||||||
| import json | ||||||
| import os | ||||||
| from typing import List | ||||||
| from urllib.parse import urlparse | ||||||
|
|
||||||
| import requests | ||||||
| from SPARQLWrapper import JSON, SPARQLWrapper | ||||||
|
|
@@ -12,6 +13,18 @@ | |||||
| ) | ||||||
|
|
||||||
|
|
||||||
| # Hosts that require Vault token based authentication. Central source of truth. | ||||||
| VAULT_REQUIRED_HOSTS = { | ||||||
| "data.dbpedia.io", | ||||||
| "data.dev.dbpedia.link", | ||||||
| } | ||||||
|
|
||||||
|
|
||||||
| class DownloadAuthError(Exception): | ||||||
| """Raised when an authorization problem occurs during download.""" | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
| def _download_file( | ||||||
| url, | ||||||
| localDir, | ||||||
|
|
@@ -52,13 +65,23 @@ def _download_file( | |||||
| os.makedirs(dirpath, exist_ok=True) # Create the necessary directories | ||||||
| # --- 1. Get redirect URL by requesting HEAD --- | ||||||
| headers = {} | ||||||
|
|
||||||
| # Determine hostname early and fail fast if this host requires Vault token. | ||||||
| # This prevents confusing 401/403 errors later and tells the user exactly | ||||||
| # what to do (provide --vault-token). | ||||||
| parsed = urlparse(url) | ||||||
| host = parsed.hostname | ||||||
| if host in VAULT_REQUIRED_HOSTS and not vault_token_file: | ||||||
| raise DownloadAuthError( | ||||||
| f"Vault token required for host '{host}', but no token was provided. Please use --vault-token." | ||||||
| ) | ||||||
|
|
||||||
| # --- 1a. public databus --- | ||||||
| response = requests.head(url, timeout=30) | ||||||
| # --- 1b. Databus API key required --- | ||||||
| if response.status_code == 401: | ||||||
| # print(f"API key required for {url}") | ||||||
| if not databus_key: | ||||||
| raise ValueError("Databus API key not given for protected download") | ||||||
| raise DownloadAuthError("Databus API key not given for protected download") | ||||||
|
|
||||||
| headers = {"X-API-KEY": databus_key} | ||||||
| response = requests.head(url, headers=headers, timeout=30) | ||||||
|
|
@@ -81,25 +104,54 @@ def _download_file( | |||||
| response = requests.get( | ||||||
| url, headers=headers, stream=True, allow_redirects=True, timeout=30 | ||||||
| ) | ||||||
| www = response.headers.get( | ||||||
| "WWW-Authenticate", "" | ||||||
| ) # Check if authentication is required | ||||||
| www = response.headers.get("WWW-Authenticate", "") # Check if authentication is required | ||||||
|
|
||||||
| # --- 3. If redirected to authentication 401 Unauthorized, get Vault token and retry --- | ||||||
| # --- 3. Handle authentication responses --- | ||||||
| # 3a. Server requests Bearer auth. Only attempt token exchange for hosts | ||||||
| # we explicitly consider Vault-protected (VAULT_REQUIRED_HOSTS). This avoids | ||||||
| # sending tokens to unrelated hosts and makes auth behavior predictable. | ||||||
| if response.status_code == 401 and "bearer" in www.lower(): | ||||||
| print(f"Authentication required for {url}") | ||||||
| if not (vault_token_file): | ||||||
| raise ValueError("Vault token file not given for protected download") | ||||||
| # If host is not configured for Vault, do not attempt token exchange. | ||||||
| if host not in VAULT_REQUIRED_HOSTS: | ||||||
| raise DownloadAuthError( | ||||||
| "Server requests Bearer authentication but this host is not configured for Vault token exchange." | ||||||
| " Try providing a databus API key with --databus-key or contact your administrator." | ||||||
| ) | ||||||
|
|
||||||
| # Host requires Vault; ensure token file provided. | ||||||
| if not vault_token_file: | ||||||
| raise DownloadAuthError( | ||||||
| f"Vault token required for host '{host}', but no token was provided. Please use --vault-token." | ||||||
| ) | ||||||
|
|
||||||
| # --- 3a. Fetch Vault token --- | ||||||
| # TODO: cache token | ||||||
| # --- 3b. Fetch Vault token and retry --- | ||||||
| # Token exchange is potentially sensitive and should only be performed | ||||||
| # for known hosts. __get_vault_access__ handles reading the refresh | ||||||
| # token and exchanging it; errors are translated to DownloadAuthError | ||||||
| # for user-friendly CLI output. | ||||||
| vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id) | ||||||
|
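There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Critical: Pass final URL to token exchange function.
Line 139 passes the original `url` to `__get_vault_access__`. The preceding GET follows redirects, so the Bearer challenge is issued by the final download location (`response.url`), not by the originally requested URL. The token exchange should therefore be performed for the final URL.
🔎 Proposed fix:
- vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id)
+ vault_token = __get_vault_access__(response.url, vault_token_file, auth_url, client_id)
📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |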
||||||
| headers["Authorization"] = f"Bearer {vault_token}" | ||||||
| headers.pop("Accept-Encoding") | ||||||
| headers.pop("Accept-Encoding", None) | ||||||
|
|
||||||
| # --- 3b. Retry with token --- | ||||||
| # Retry with token | ||||||
| response = requests.get(url, headers=headers, stream=True, timeout=30) | ||||||
|
|
||||||
| # Map common auth failures to friendly messages | ||||||
| if response.status_code == 401: | ||||||
| raise DownloadAuthError("Vault token is invalid or expired. Please generate a new token.") | ||||||
| if response.status_code == 403: | ||||||
| raise DownloadAuthError("Vault token is valid but has insufficient permissions to access this file.") | ||||||
|
|
||||||
| # 3c. Generic forbidden without Bearer challenge | ||||||
| if response.status_code == 403: | ||||||
| raise DownloadAuthError("Access forbidden: your token or API key does not have permission to download this file.") | ||||||
|
|
||||||
| # 3d. Generic unauthorized without Bearer | ||||||
| if response.status_code == 401: | ||||||
| raise DownloadAuthError( | ||||||
| "Unauthorized: access denied. Check your --databus-key or --vault-token settings." | ||||||
| ) | ||||||
|
|
||||||
| try: | ||||||
| response.raise_for_status() # Raise if still failing | ||||||
| except requests.exceptions.HTTPError as e: | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| import sys | ||
| import types | ||
|
|
||
| # Provide a lightweight fake SPARQLWrapper module for tests when not installed. | ||
| if "SPARQLWrapper" not in sys.modules: | ||
| mod = types.ModuleType("SPARQLWrapper") | ||
| mod.JSON = None | ||
|
|
||
| class DummySPARQL: | ||
| def __init__(self, *args, **kwargs): | ||
| pass | ||
|
|
||
| def setQuery(self, q): | ||
| self._q = q | ||
|
|
||
| def setReturnFormat(self, f): | ||
| self._fmt = f | ||
|
|
||
| def setCustomHttpHeaders(self, h): | ||
| self._headers = h | ||
|
|
||
| def query(self): | ||
| class R: | ||
| def convert(self): | ||
| return {"results": {"bindings": []}} | ||
|
|
||
| return R() | ||
|
|
||
| mod.SPARQLWrapper = DummySPARQL | ||
| sys.modules["SPARQLWrapper"] = mod |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| from unittest.mock import Mock, patch | ||
|
|
||
| import pytest | ||
|
|
||
| import requests | ||
|
|
||
| import databusclient.api.download as dl | ||
|
|
||
| from databusclient.api.download import VAULT_REQUIRED_HOSTS, DownloadAuthError | ||
|
|
||
|
|
||
| def make_response(status=200, headers=None, content=b""): | ||
| headers = headers or {} | ||
| mock = Mock() | ||
| mock.status_code = status | ||
| mock.headers = headers | ||
| mock.content = content | ||
|
|
||
| def iter_content(chunk_size): | ||
| if content: | ||
| yield content | ||
| else: | ||
| return | ||
|
|
||
| mock.iter_content = lambda chunk: iter(iter_content(chunk)) | ||
|
|
||
| def raise_for_status(): | ||
| if mock.status_code >= 400: | ||
| raise requests.exceptions.HTTPError() | ||
|
|
||
| mock.raise_for_status = raise_for_status | ||
| return mock | ||
|
|
||
|
|
||
| def test_vault_host_no_token_raises(): | ||
| vault_host = next(iter(VAULT_REQUIRED_HOSTS)) | ||
| url = f"https://{vault_host}/some/protected/file.ttl" | ||
|
|
||
| with pytest.raises(DownloadAuthError) as exc: | ||
| dl._download_file(url, localDir='.', vault_token_file=None) | ||
|
|
||
| assert "Vault token required" in str(exc.value) | ||
|
|
||
|
|
||
| def test_non_vault_host_no_token_allows_download(monkeypatch): | ||
| url = "https://example.com/public/file.txt" | ||
|
|
||
| resp_head = make_response(status=200, headers={}) | ||
| resp_get = make_response(status=200, headers={"content-length": "0"}, content=b"") | ||
|
|
||
| with patch("requests.head", return_value=resp_head), patch( | ||
| "requests.get", return_value=resp_get | ||
| ): | ||
| # should not raise | ||
| dl._download_file(url, localDir='.', vault_token_file=None) | ||
|
|
||
|
|
||
| def test_401_after_token_exchange_reports_invalid_token(monkeypatch): | ||
| vault_host = next(iter(VAULT_REQUIRED_HOSTS)) | ||
| url = f"https://{vault_host}/protected/file.ttl" | ||
|
|
||
| # initial head and get -> 401 with Bearer | ||
| resp_head = make_response(status=200, headers={}) | ||
| resp_401 = make_response(status=401, headers={"WWW-Authenticate": "Bearer realm=\"auth\""}) | ||
|
|
||
| # after retry with token -> still 401 | ||
| resp_401_retry = make_response(status=401, headers={}) | ||
|
|
||
| # Mock requests.get side effects: first 401 (challenge), then 401 after token | ||
| get_side_effects = [resp_401, resp_401_retry] | ||
|
|
||
| # Mock token exchange responses | ||
| post_resp_1 = Mock() | ||
| post_resp_1.json.return_value = {"access_token": "ACCESS"} | ||
| post_resp_2 = Mock() | ||
| post_resp_2.json.return_value = {"access_token": "VAULT"} | ||
|
|
||
| with patch("requests.head", return_value=resp_head), patch( | ||
| "requests.get", side_effect=get_side_effects | ||
| ), patch("requests.post", side_effect=[post_resp_1, post_resp_2]): | ||
| # set REFRESH_TOKEN so __get_vault_access__ doesn't try to open a file | ||
| monkeypatch.setenv("REFRESH_TOKEN", "x" * 90) | ||
|
|
||
| with pytest.raises(DownloadAuthError) as exc: | ||
| dl._download_file(url, localDir='.', vault_token_file="/does/not/matter") | ||
|
|
||
| assert "invalid or expired" in str(exc.value) | ||
|
|
||
|
|
||
| def test_403_reports_insufficient_permissions(): | ||
| vault_host = next(iter(VAULT_REQUIRED_HOSTS)) | ||
| url = f"https://{vault_host}/protected/file.ttl" | ||
|
|
||
| resp_head = make_response(status=200, headers={}) | ||
| resp_403 = make_response(status=403, headers={}) | ||
|
|
||
| with patch("requests.head", return_value=resp_head), patch( | ||
| "requests.get", return_value=resp_403 | ||
| ): | ||
| # provide a token path so early check does not block | ||
| with pytest.raises(DownloadAuthError) as exc: | ||
| dl._download_file(url, localDir='.', vault_token_file="/some/token/file") | ||
|
|
||
| assert "permission" in str(exc.value) or "forbidden" in str(exc.value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The client gets a Databus URI as an argument, which will be downloaded, e.g.:
`databusclient download https://databus.dev.dbpedia.link/fhofer/live-fusion-kg-snapshot/geo-coordinates/2025-11-11/geo-coordinates_graph\=dbpedia-io_partition\=wgs-lat.ttl.gz --vault-token vault-token.dat`
The Databus sends a redirect URI back to the client, which represents the actual download location. For the example above, this is a URL on a different host than the Databus URI (the concrete redirect URL is not preserved in this export).
So you have to check the redirect URI, not the Databus URI, to determine whether authentication is required.
You can read here about how to get an access token, so that you can test with one of the KGs that require Vault token authentication, e.g. live-fusion-kg.