Skip to content

Commit

Permalink
Disable ability to provide relative paths in sources (#8281)
Browse files Browse the repository at this point in the history
* Disable ability to provide relative paths in sources

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>

* no relative paths allowed

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>

---------

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
  • Loading branch information
BenWilson2 committed Apr 20, 2023
1 parent 9ac6494 commit f731474
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 1 deletion.
34 changes: 34 additions & 0 deletions mlflow/server/handlers.py
Expand Up @@ -1323,6 +1323,36 @@ def _delete_registered_model_tag():
return _wrap_response(DeleteRegisteredModelTag.Response())


def _validate_non_local_source_contains_relative_paths(source: str):
"""
Validation check to ensure that sources that are provided that conform to the schemes:
http, https, or mlflow-artifacts do not contain relative path designations that are intended
to access local file system paths on the tracking server.
Example paths that this validation function is intended to find and raise an Exception if
passed:
"mlflow-artifacts://host:port/../../../../"
"http://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
"https://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
"/models/artifacts/../../../"
"s3:/my_bucket/models/path/../../other/path"
"file://path/to/../../../../some/where/you/should/not/be"
"""
source_path = urllib.parse.urlparse(source).path
resolved_source = pathlib.Path(source_path).resolve().as_posix()
# NB: drive split is specifically for Windows since WindowsPath.resolve() will append the
# drive path of the pwd to a given path. We don't care about the drive here, though.
_, resolved_path = os.path.splitdrive(resolved_source)

if resolved_path != source_path:
raise MlflowException(
f"Invalid model version source: '{source}'. If supplying a source as an http, https, "
"local file path, ftp, objectstore, or mlflow-artifacts uri, an absolute path must be "
"provided without relative path references present. Please provide an absolute path.",
INVALID_PARAMETER_VALUE,
)


def _validate_source(source: str, run_id: str) -> None:
if is_local_uri(source):
if run_id:
Expand Down Expand Up @@ -1352,6 +1382,10 @@ def _validate_source(source: str, run_id: str) -> None:
INVALID_PARAMETER_VALUE,
)

# Checks if relative paths are present in the source (a security threat). If any are present,
# raises an Exception.
_validate_non_local_source_contains_relative_paths(source)


@catch_mlflow_exception
@_disable_if_artifacts_only
Expand Down
96 changes: 95 additions & 1 deletion tests/tracking/test_rest_tracking.py
Expand Up @@ -1045,7 +1045,7 @@ def get(self, key, default=None):


def test_create_model_version_with_path_source(mlflow_client):
name = "mode"
name = "model"
mlflow_client.create_registered_model(name)
exp_id = mlflow_client.create_experiment("test")
run = mlflow_client.create_run(experiment_id=exp_id)
Expand Down Expand Up @@ -1084,6 +1084,100 @@ def test_create_model_version_with_path_source(mlflow_client):
assert "To use a local path as a model version" in response.json()["message"]


def test_create_model_version_with_non_local_source(mlflow_client, monkeypatch):
    """
    Exercise the model-versions/create REST endpoint with several sources.

    Sources expected to succeed (HTTP 200) are posted first, followed by remote URIs whose
    paths contain ``..`` segments, each of which must be answered with HTTP 400 and the
    "absolute path" validation message.
    """
    model_name = "model"
    mlflow_client.create_registered_model(model_name)
    experiment_id = mlflow_client.create_experiment("test")
    run = mlflow_client.create_run(experiment_id=experiment_id)
    endpoint = f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create"

    def post_create(source):
        # Every request shares the model name and run id; only the source varies.
        payload = {
            "name": model_name,
            "source": source,
            "run_id": run.info.run_id,
        }
        return requests.post(endpoint, json=payload)

    accepted_sources = [
        # The run's artifact location with the leading "file://" characters sliced off.
        run.info.artifact_uri[len("file://") :],
        # Remote URIs supplied as a source with absolute paths work fine.
        "mlflow-artifacts:/models",
        "mlflow-artifacts://host:9000/models",
    ]
    for accepted in accepted_sources:
        response = post_create(accepted)
        assert response.status_code == 200

    # Remote URIs containing relative path references cannot be used to create a version.
    rejected_sources = [
        "mlflow-artifacts://host:9000/models/../../../",
        "http://host:9000/models/../../../",
        "https://host/api/2.0/mlflow-artifacts/artifacts/../../../",
        "s3a://my_bucket/api/2.0/mlflow-artifacts/artifacts/../../../",
        "ftp://host:8888/api/2.0/mlflow-artifacts/artifacts/../../../",
    ]
    for rejected in rejected_sources:
        response = post_create(rejected)
        assert response.status_code == 400
        assert "If supplying a source as an http, https," in response.json()["message"]


def test_create_model_version_with_file_uri(mlflow_client):
name = "test"
mlflow_client.create_registered_model(name)
Expand Down

0 comments on commit f731474

Please sign in to comment.