Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attestation service in Syft #8659

Merged
merged 30 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f7b0661
Merge branch 'rasswanth/cc-attestation' into snwagh/attestation-service
snwagh Apr 3, 2024
8807020
Adding attestation service
snwagh Apr 3, 2024
5eecc92
Updating endpoints to also return raw_token if requested
snwagh Apr 5, 2024
4d57fbb
Adding GPU code for token extraction and combining result and token r…
snwagh Apr 5, 2024
798f8a2
Merge branch 'dev' into snwagh/attestation-service
snwagh Apr 8, 2024
0f2d081
Fixed linting errors
snwagh Apr 8, 2024
a669719
Merge conflicts resolved
snwagh Apr 10, 2024
1569d59
Adding CPU, GPU raw_token argument to endpoints
snwagh Apr 10, 2024
d178db8
Removing notebook for partially completed feature
snwagh Apr 11, 2024
3cc8cf7
Merge branch 'dev' into snwagh/attestation-service
snwagh Apr 11, 2024
2d3094c
Merge branch 'dev' into snwagh/attestation-service
snwagh Apr 12, 2024
ea80271
Merge branch 'tauquir/attestation-helmcharts' into snwagh/attestation…
snwagh Apr 16, 2024
507d294
Merge branch 'tauquir/attestation-helmcharts' into snwagh/attestation…
snwagh Apr 19, 2024
f08065b
Merge branch 'tauquir/attestation-helmcharts' into snwagh/attestation…
snwagh Apr 22, 2024
b9f913e
Merge branch 'tauquir/attestation-helmcharts' into snwagh/attestation…
snwagh Apr 23, 2024
a09fd2f
Adding local verification code in case we revisit enclaves
snwagh Apr 23, 2024
20a60d8
Merge branch 'tauquir/attestation-helmcharts' into snwagh/attestation…
snwagh Apr 25, 2024
b7c012d
Adding notes for deployment of the attestation pieces
snwagh Apr 25, 2024
3e50111
Merge branch 'snwagh/attestation-service' of github.com:OpenMined/PyS…
snwagh Apr 25, 2024
e2080f2
Adding the pending secret item into this PR.
snwagh Apr 25, 2024
1db5bc9
Fixing PR comments
snwagh Apr 30, 2024
d0c2cad
Merge branch 'dev' into snwagh/attestation-service
snwagh Apr 30, 2024
4e0c9b4
Merge branch 'dev' into snwagh/attestation-service
snwagh May 2, 2024
cde3900
Merge branch 'dev' into snwagh/attestation-service
rasswanth-s May 3, 2024
e062859
Addressing Rasswanth's PR comments
snwagh May 3, 2024
dabd1ce
Fixed precommit errors
snwagh May 3, 2024
b82ee87
Merge branch 'dev' into snwagh/attestation-service
snwagh May 3, 2024
249e309
added str_to_bool in attestation service
rasswanth-s May 6, 2024
8c70884
Merge branch 'dev' into snwagh/attestation-service
rasswanth-s May 6, 2024
7c8c19a
Merge branch 'dev' into snwagh/attestation-service
rasswanth-s May 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
104 changes: 104 additions & 0 deletions packages/grid/enclave/attestation/enclave-development.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,107 @@ print ("[RemoteGPUTest] node name :", client.get_name())
client.add_verifier(attestation.Devices.GPU, attestation.Environment.REMOTE, NRAS_URL, "")
client.attest()
```

### Instructions for using helm charts

- The attestation container runs inside the backend pod (so backend pod has two containers now). However, in order to run the attestation container, you need to uncomment the attestation flags in `packages/grid/helm/values.dev.yaml`
- Next, we run the deployment. Since k3d creates an intermediate layer of nesting, we need to mount some volumes from host to k3d registry. Thus, when launching, use the following tox command `tox -e dev.k8s.start -- --volume /sys/kernel/security:/sys/kernel/security --volume /dev/tmprm0:/dev/tmprm0`
- Finally, note that the GPU privileges/drivers etc. have not been completed so while the GPU attestation endpoints should work, they will not produce the expected tokens. To test the GPU code, follow the steps provided in [For GPU Attestation
](#for-gpu-attestation) to look at the tokens.

### Local Client-side Verification

Use the following function to perform local, client-side verification of tokens. They expire quick.

```python3
def verify_token(token: str, token_type: str):
"""
Verifies a JSON Web Token (JWT) using a public key obtained from a JWKS (JSON Web Key Set) endpoint,
based on the specified type of token ('cpu' or 'gpu'). The function handles two distinct processes
for token verification depending on the type specified:

- 'cpu': Fetches the JWKS from the 'jku' URL specified in the JWT's unverified header,
finds the key by 'kid', and converts the JWK to a PEM format public key for verification.

- 'gpu': Directly uses a fixed JWKS URL to retrieve the keys, finds the key by 'kid', and uses the
'x5c' field to extract a certificate which is then used to verify the token.

Parameters:
token (str): The JWT that needs to be verified.
type (str): Type of the token which dictates the verification process; expected values are 'cpu' or 'gpu'.

Returns:
bool: True if the JWT is successfully verified, False otherwise.

Raises:
Exception: Raises various exceptions internally but catches them to return False, except for
printing error messages related to the specific failures (e.g., key not found, invalid certificate).

Example usage:
verify_token('your.jwt.token', 'cpu')
verify_token('your.jwt.token', 'gpu')

Note:
- The function prints out details about the verification process and errors, if any.
- Ensure that the cryptography and PyJWT libraries are properly installed and updated in your environment.
"""
import jwt
import json
import base64
import requests
from jwt.algorithms import RSAAlgorithm
from cryptography.x509 import load_der_x509_certificate
from cryptography.hazmat.primitives import serialization


# Determine JWKS URL based on the token type
if token_type.lower() == "gpu":
jwks_url = 'https://nras.attestation.nvidia.com/.well-known/jwks.json'
else:
unverified_header = jwt.get_unverified_header(token)
jwks_url = unverified_header['jku']

# Fetch the JWKS from the endpoint
jwks = requests.get(jwks_url).json()

# Get the key ID from the JWT header
header = jwt.get_unverified_header(token)
kid = header['kid']

# Find the key with the matching kid in the JWKS
key = next((item for item in jwks["keys"] if item["kid"] == kid), None)
if not key:
print("Public key not found in JWKS list.")
return False

# Convert the key based on the token type
if token_type.lower() == "gpu" and "x5c" in key:
try:
cert_bytes = base64.b64decode(key['x5c'][0])
cert = load_der_x509_certificate(cert_bytes)
public_key = cert.public_key()
except Exception as e:
print("Failed to process certificate:", str(e))
return False
elif token_type.lower() == "cpu":
try:
public_key = RSAAlgorithm.from_jwk(key)
except Exception as e:
print("Failed to convert JWK to PEM:", str(e))
return False
else:
print("Invalid token_type or key information.")
return False

# Verify the JWT using the public key
try:
payload = jwt.decode(token, public_key, algorithms=[header['alg']], options={"verify_exp": True})
print("JWT Payload:", json.dumps(payload, indent=2))
return True
except jwt.ExpiredSignatureError:
print("JWT token has expired.")
except jwt.InvalidTokenError as e:
print("JWT token signature is invalid:", str(e))

return False
```
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
NRAS_URL = "https://nras.attestation.nvidia.com/v1/attest/gpu"
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from loguru import logger

# relative
from .attestation_models import CPUAttestationResponseModel
from .attestation_models import GPUAttestationResponseModel
from .attestation_models import ResponseModel
from .cpu_attestation import attest_cpu
from .gpu_attestation import attest_gpu
from .models import CPUAttestationResponseModel
from .models import GPUAttestationResponseModel
from .models import ResponseModel

# Logging Configuration
log_level = os.getenv("APP_LOG_LEVEL", "INFO").upper()
Expand All @@ -28,11 +28,11 @@ async def read_root() -> ResponseModel:

@app.get("/attest/cpu", response_model=CPUAttestationResponseModel)
async def attest_cpu_endpoint() -> CPUAttestationResponseModel:
cpu_attest_res = attest_cpu()
return CPUAttestationResponseModel(result=cpu_attest_res)
cpu_attest_res, cpu_attest_token = attest_cpu()
return CPUAttestationResponseModel(result=cpu_attest_res, token=cpu_attest_token)


@app.get("/attest/gpu", response_model=GPUAttestationResponseModel)
async def attest_gpu_endpoint() -> GPUAttestationResponseModel:
gpu_attest_res = attest_gpu()
return GPUAttestationResponseModel(result=gpu_attest_res)
gpu_attest_res, gpu_attest_token = attest_gpu()
return GPUAttestationResponseModel(result=gpu_attest_res, token=gpu_attest_token)
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ class ResponseModel(BaseModel):

class CPUAttestationResponseModel(BaseModel):
result: str
token: str = ""
vendor: str | None = None # Hardware Manufacturer


class GPUAttestationResponseModel(BaseModel):
result: str
token: str = ""
vendor: str | None = None # Hardware Manufacturer
15 changes: 12 additions & 3 deletions packages/grid/enclave/attestation/server/cpu_attestation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from loguru import logger


def attest_cpu() -> str:
def attest_cpu() -> tuple[str, str]:
# Fetch report from Micrsoft Attestation library
cpu_report = subprocess.run(
["/app/AttestationClient"], capture_output=True, text=True
Expand All @@ -14,7 +14,16 @@ def attest_cpu() -> str:
logger.debug(f"Stderr: {cpu_report.stderr}")

logger.info("Attestation Return Code: {}", cpu_report.returncode)
res = "False"
if cpu_report.returncode == 0 and cpu_report.stdout == "true":
return "True"
res = "True"

return "False"
# Fetch token from Micrsoft Attestation library
cpu_token = subprocess.run(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving a note , for later, we do two calls, to the attestation client library , we could combine them in to a single call in later PR's

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will shelf this as a TODO item, good point though.

["/app/AttestationClient", "-o", "token"], capture_output=True, text=True
)
logger.debug(f"Stdout: {cpu_token.stdout}")
logger.debug(f"Stderr: {cpu_token.stderr}")

logger.info("Attestation Token Return Code: {}", cpu_token.returncode)
return res, cpu_token.stdout
40 changes: 36 additions & 4 deletions packages/grid/enclave/attestation/server/gpu_attestation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
# standard imports
snwagh marked this conversation as resolved.
Show resolved Hide resolved
# stdlib
import io
import re
import sys

# third party
from loguru import logger
from nv_attestation_sdk import attestation

NRAS_URL = "https://nras.attestation.nvidia.com/v1/attest/gpu"
# relative
# relative imports
snwagh marked this conversation as resolved.
Show resolved Hide resolved
from .attestation_constants import NRAS_URL


# Function to process captured output to extract the token
def extract_token(captured_value: str) -> str:
match = re.search(r"Entity Attestation Token is (\S+)", captured_value)
if match:
token = match.group(1) # Extract the token, which is in group 1 of the match
return token
else:
return "Token not found"


def attest_gpu() -> str:
def attest_gpu() -> tuple[str, str]:
# Fetch report from Nvidia Attestation SDK
client = attestation.Attestation("Attestation Node")

Expand All @@ -15,7 +33,21 @@ def attest_gpu() -> str:
client.add_verifier(
attestation.Devices.GPU, attestation.Environment.REMOTE, NRAS_URL, ""
)

# Step 1: Redirect stdout
original_stdout = sys.stdout # Save a reference to the original standard output
captured_output = io.StringIO() # Create a StringIO object to capture output
sys.stdout = captured_output # Redirect stdout to the StringIO object

# Step 2: Call the function
gpu_report = client.attest()
logger.info("[RemoteGPUTest] report : {}, {}", gpu_report, type(gpu_report))

return str(gpu_report)
# Step 3: Get the content of captured output and reset stdout
captured_value = captured_output.getvalue()
sys.stdout = original_stdout # Reset stdout to its original state

# Step 4: Extract the token from the captured output
token = extract_token(captured_value)

logger.info("[RemoteGPUTest] report : {}, {}", gpu_report, type(gpu_report))
return str(gpu_report), token
2 changes: 1 addition & 1 deletion packages/grid/enclave/attestation/start.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
set -e
export PATH="/root/.local/bin:${PATH}"

APP_MODULE=server.main:app
APP_MODULE=server.attestation_main:app
snwagh marked this conversation as resolved.
Show resolved Hide resolved
APP_LOG_LEVEL=${APP_LOG_LEVEL:-info}
UVICORN_LOG_LEVEL=${UVICORN_LOG_LEVEL:-info}
HOST=${HOST:-0.0.0.0}
Expand Down
6 changes: 6 additions & 0 deletions packages/syft/src/syft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,11 @@ def _orchestra() -> Orchestra:
return Orchestra


@module_property
def hello_baby() -> None:
print("Hello baby!")
print("Welcome to the world. \u2764\ufe0f")


snwagh marked this conversation as resolved.
Show resolved Hide resolved
def search(name: str) -> SearchResults:
return Search(_domains()).search(name=name)
2 changes: 2 additions & 0 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from ..service.action.action_store import MongoActionStore
from ..service.action.action_store import SQLiteActionStore
from ..service.api.api_service import APIService
from ..service.attestation.attestation_service import AttestationService
from ..service.blob_storage.service import BlobStorageService
from ..service.code.status_service import UserCodeStatusService
from ..service.code.user_code_service import UserCodeService
Expand Down Expand Up @@ -877,6 +878,7 @@ def _construct_services(self) -> None:
default_services: list[dict] = [
{"svc": ActionService, "store": self.action_store},
{"svc": UserService},
{"svc": AttestationService},
{"svc": WorkerService},
{"svc": SettingsService},
{"svc": DatasetService},
Expand Down
rasswanth-s marked this conversation as resolved.
Show resolved Hide resolved
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
ATTESTATION_SERVICE_URL = (
"http://localhost:4455" # Replace with "http://attestation:4455"
)
ATTEST_CPU_ENDPOINT = "/attest/cpu"
ATTEST_GPU_ENDPOINT = "/attest/gpu"
60 changes: 60 additions & 0 deletions packages/syft/src/syft/service/attestation/attestation_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# stdlib
from collections.abc import Callable

# third party
import requests

# relative
from ...serde.serializable import serializable
from ...store.document_store import DocumentStore
from ..context import AuthedServiceContext
from ..response import SyftError
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..user.user_roles import GUEST_ROLE_LEVEL
from .attestation_constants import ATTESTATION_SERVICE_URL
from .attestation_constants import ATTEST_CPU_ENDPOINT
from .attestation_constants import ATTEST_GPU_ENDPOINT


@serializable()
class AttestationService(AbstractService):
"""This service is responsible for getting all sorts of attestations for any client."""

def __init__(self, store: DocumentStore) -> None:
self.store = store

def perform_request(
self, method: Callable, endpoint: str, raw: bool = False
) -> SyftSuccess | SyftError | str:
try:
response = method(f"{ATTESTATION_SERVICE_URL}{endpoint}")
response.raise_for_status()
message = response.json().get("result")
raw_token = response.json().get("token")
return raw_token if raw else SyftSuccess(message=message)
snwagh marked this conversation as resolved.
Show resolved Hide resolved
except requests.HTTPError:
return SyftError(message=f"{response.json()['detail']}")
except requests.RequestException as e:
return SyftError(message=f"Failed to perform request. {e}")

@service_method(
path="attestation.get_cpu_attestation",
name="get_cpu_attestation",
roles=GUEST_ROLE_LEVEL,
)
def get_cpu_attestation(
self, context: AuthedServiceContext, raw_token: bool = False
) -> str | SyftError | SyftSuccess:
return self.perform_request(requests.get, ATTEST_CPU_ENDPOINT, raw_token)

@service_method(
path="attestation.get_gpu_attestation",
name="get_gpu_attestation",
roles=GUEST_ROLE_LEVEL,
)
def get_gpu_attestation(
self, context: AuthedServiceContext, raw_token: bool = False
) -> str | SyftError | SyftSuccess:
return self.perform_request(requests.get, ATTEST_GPU_ENDPOINT, raw_token)