From f73147496e05c09a8b83d95fb4f1bf86696c6342 Mon Sep 17 00:00:00 2001 From: Ben Wilson <39283302+BenWilson2@users.noreply.github.com> Date: Thu, 20 Apr 2023 09:51:51 -0400 Subject: [PATCH] Disable ability to provide relative paths in sources (#8281) * Disable ability to provide relative paths in sources Signed-off-by: Ben Wilson * no relative paths allowed Signed-off-by: Ben Wilson --------- Signed-off-by: Ben Wilson --- mlflow/server/handlers.py | 34 ++++++++++ tests/tracking/test_rest_tracking.py | 96 +++++++++++++++++++++++++++- 2 files changed, 129 insertions(+), 1 deletion(-) diff --git a/mlflow/server/handlers.py b/mlflow/server/handlers.py index d7423f50fea69..7c2f55183eb5e 100644 --- a/mlflow/server/handlers.py +++ b/mlflow/server/handlers.py @@ -1323,6 +1323,36 @@ def _delete_registered_model_tag(): return _wrap_response(DeleteRegisteredModelTag.Response()) +def _validate_non_local_source_contains_relative_paths(source: str): + """ + Validation check to ensure that sources that are provided that conform to the schemes: + http, https, or mlflow-artifacts do not contain relative path designations that are intended + to access local file system paths on the tracking server. + + Example paths that this validation function is intended to find and raise an Exception if + passed: + "mlflow-artifacts://host:port/../../../../" + "http://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../" + "https://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../" + "/models/artifacts/../../../" + "s3:/my_bucket/models/path/../../other/path" + "file://path/to/../../../../some/where/you/should/not/be" + """ + source_path = urllib.parse.urlparse(source).path + resolved_source = pathlib.Path(source_path).resolve().as_posix() + # NB: drive split is specifically for Windows since WindowsPath.resolve() will append the + # drive path of the pwd to a given path. We don't care about the drive here, though. + _, resolved_path = os.path.splitdrive(resolved_source) + + if resolved_path != source_path: + raise MlflowException( + f"Invalid model version source: '{source}'. If supplying a source as an http, https, " + "local file path, ftp, objectstore, or mlflow-artifacts uri, an absolute path must be " + "provided without relative path references present. Please provide an absolute path.", + INVALID_PARAMETER_VALUE, + ) + + def _validate_source(source: str, run_id: str) -> None: if is_local_uri(source): if run_id: @@ -1352,6 +1382,10 @@ def _validate_source(source: str, run_id: str) -> None: INVALID_PARAMETER_VALUE, ) + # Checks if relative paths are present in the source (a security threat). If any are present, + # raises an Exception. + _validate_non_local_source_contains_relative_paths(source) + @catch_mlflow_exception @_disable_if_artifacts_only diff --git a/tests/tracking/test_rest_tracking.py b/tests/tracking/test_rest_tracking.py index c8f4a2b6c552f..5748590d8a131 100644 --- a/tests/tracking/test_rest_tracking.py +++ b/tests/tracking/test_rest_tracking.py @@ -1045,7 +1045,7 @@ def get(self, key, default=None): def test_create_model_version_with_path_source(mlflow_client): - name = "mode" + name = "model" mlflow_client.create_registered_model(name) exp_id = mlflow_client.create_experiment("test") run = mlflow_client.create_run(experiment_id=exp_id) @@ -1084,6 +1084,100 @@ def test_create_model_version_with_path_source(mlflow_client): assert "To use a local path as a model version" in response.json()["message"] +def test_create_model_version_with_non_local_source(mlflow_client, monkeypatch): + name = "model" + mlflow_client.create_registered_model(name) + exp_id = mlflow_client.create_experiment("test") + run = mlflow_client.create_run(experiment_id=exp_id) + + response = requests.post( + f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create", + json={ + "name": name, + "source": run.info.artifact_uri[len("file://") :], + "run_id": run.info.run_id, + }, + ) + assert response.status_code == 200 + + # Test that remote uri's supplied as a source with absolute paths work fine + response = requests.post( + f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create", + json={ + "name": name, + "source": "mlflow-artifacts:/models", + "run_id": run.info.run_id, + }, + ) + assert response.status_code == 200 + + response = requests.post( + f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create", + json={ + "name": name, + "source": "mlflow-artifacts://host:9000/models", + "run_id": run.info.run_id, + }, + ) + assert response.status_code == 200 + + # Test that invalid remote uri's cannot be created + response = requests.post( + f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create", + json={ + "name": name, + "source": "mlflow-artifacts://host:9000/models/../../../", + "run_id": run.info.run_id, + }, + ) + assert response.status_code == 400 + assert "If supplying a source as an http, https," in response.json()["message"] + + response = requests.post( + f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create", + json={ + "name": name, + "source": "http://host:9000/models/../../../", + "run_id": run.info.run_id, + }, + ) + assert response.status_code == 400 + assert "If supplying a source as an http, https," in response.json()["message"] + + response = requests.post( + f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create", + json={ + "name": name, + "source": "https://host/api/2.0/mlflow-artifacts/artifacts/../../../", + "run_id": run.info.run_id, + }, + ) + assert response.status_code == 400 + assert "If supplying a source as an http, https," in response.json()["message"] + + response = requests.post( + f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create", + json={ + "name": name, + "source": "s3a://my_bucket/api/2.0/mlflow-artifacts/artifacts/../../../", + "run_id": run.info.run_id, + }, + ) + assert response.status_code == 400 + assert "If supplying a source as an http, https," in response.json()["message"] + + response = requests.post( + f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create", + json={ + "name": name, + "source": "ftp://host:8888/api/2.0/mlflow-artifacts/artifacts/../../../", + "run_id": run.info.run_id, + }, + ) + assert response.status_code == 400 + assert "If supplying a source as an http, https," in response.json()["message"] + + def test_create_model_version_with_file_uri(mlflow_client): name = "test" mlflow_client.create_registered_model(name)