Skip to content

Commit

Permalink
CRC: make data mapping consistent with RiskSeverity enum, remove unne…
Browse files Browse the repository at this point in the history
…cessary intermediate struct: RiskResponse
  • Loading branch information
Marzi committed May 13, 2024
1 parent 15c0669 commit 42db0b7
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 94 deletions.
20 changes: 8 additions & 12 deletions blocklist-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,18 @@ path = "src/main.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
config = "0.11"
once_cell = "1.8.0"
reqwest = { version = "0.11", features = ["json"] }
sbtc-signer = { path = "../signer" }
serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
tokio = { version = "1.32.0", features = ["rt-multi-thread", "rt", "macros"] }
tracing.workspace = true
tracing-attributes.workspace = true
serde.workspace = true
serde_json.workspace = true
tracing-subscriber = { version = "0.3", default-features = false, features = ["env-filter", "fmt", "json", "time"] }
warp = "0.3"
tokio = { version = "1.32.0", features = ["rt-multi-thread", "rt", "macros"] }
reqwest = { version = "0.11", features = ["json"] }
config = "0.11"
once_cell = "1.8.0"
sbtc-signer = { path = "../signer" }

[dependencies.tracing-subscriber]
version = "0.3"
default-features = false
features = ["env-filter", "fmt", "json", "time"]

[dev-dependencies]
mockito = "0.28"
102 changes: 39 additions & 63 deletions blocklist-client/src/client/risk_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::common::{BlocklistStatus, RiskSeverity};
use crate::config::RiskAnalysisConfig;
use reqwest::{Client, Response, StatusCode};
use serde::Deserialize;
use std::error::Error as StdError;
use tracing::debug;
const API_BASE_PATH: &str = "/api/risk/v2/entities";

Expand All @@ -12,16 +13,11 @@ struct RegistrationResponse {
}

#[derive(Deserialize, Debug)]
struct RiskResponse {
risk: Option<String>,
#[serde(rename = "riskReason")]
risk_reason: Option<String>,
}

#[derive(Deserialize, Debug)]
struct RiskAssessment {
pub struct RiskAssessment {
#[serde(rename = "risk")]
pub severity: RiskSeverity,
pub reason: String,
#[serde(rename = "riskReason")]
pub reason: Option<String>,
}

/// Register the user address with provider to run subsequent risk checks
Expand Down Expand Up @@ -66,23 +62,25 @@ async fn get_risk_assessment(
.await?;

let checked_response = check_api_response(response).await?;
let resp = checked_response.json::<RiskResponse>().await?;

match resp.risk {
Some(risk_str) => {
let severity = match risk_str.as_str() {
"Low" => RiskSeverity::Low,
"Medium" => RiskSeverity::Medium,
"High" => RiskSeverity::High,
"Severe" => RiskSeverity::Severe,
_ => return Err(Error::InvalidRiskValue(risk_str)),
};
Ok(RiskAssessment {
severity,
reason: resp.risk_reason.unwrap_or_default(),
})
let resp_result = checked_response.json::<RiskAssessment>().await;

match resp_result {
Ok(resp) => Ok(resp),
Err(e) if e.is_decode() => {
// Check if the source of the error is serde_json::Error
if let Some(serde_err) = e
.source()
.and_then(|cause| cause.downcast_ref::<serde_json::Error>())
{
match serde_err.classify() {
serde_json::error::Category::Data => Err(Error::InvalidApiResponse),
_ => Err(Error::Serialization(serde_err.to_string())),
}
} else {
Err(Error::Network(e))
}
}
None => Err(Error::InvalidApiResponse),
Err(e) => Err(Error::Network(e)),
}
}

Expand All @@ -100,15 +98,15 @@ pub async fn check_address(
// If registration is successful, proceed to check the address
let RiskAssessment { severity, reason } = get_risk_assessment(client, config, address).await?;
debug!(
"Received risk assessment: Severity = {}, Reason = {}",
"Received risk assessment: Severity = {}, Reason = {:?}",
severity, reason
);

let is_severe = matches!(severity, RiskSeverity::Severe);
let is_severe = severity.is_severe();
let blocklist_status = BlocklistStatus {
// `is_blocklisted` is set to true if risk is Severe
is_blocklisted: is_severe,
severity: severity.to_string(),
severity,
// `accept` is set to false if severity is Severe
accept: !is_severe,
reason,
Expand All @@ -121,18 +119,18 @@ pub async fn check_address(
async fn check_api_response(response: Response) -> Result<Response, Error> {
match response.status() {
StatusCode::OK | StatusCode::CREATED => Ok(response),
StatusCode::BAD_REQUEST => Err(Error::HttpRequestErr(
StatusCode::BAD_REQUEST => Err(Error::HttpRequest(
response.status(),
"Bad request - Invalid parameters or data".to_string(),
)),
StatusCode::FORBIDDEN => Err(Error::Unauthorized),
StatusCode::NOT_FOUND => Err(Error::NotFound),
StatusCode::NOT_ACCEPTABLE => Err(Error::NotAcceptable),
StatusCode::CONFLICT => Err(Error::Conflict),
StatusCode::INTERNAL_SERVER_ERROR => Err(Error::InternalServerErr),
StatusCode::INTERNAL_SERVER_ERROR => Err(Error::InternalServer),
StatusCode::SERVICE_UNAVAILABLE => Err(Error::ServiceUnavailable),
StatusCode::REQUEST_TIMEOUT => Err(Error::RequestTimeout),
status => Err(Error::HttpRequestErr(
status => Err(Error::HttpRequest(
status,
"Unhandled status code".to_string(),
)),
Expand Down Expand Up @@ -200,11 +198,11 @@ mod tests {

let result = register_address(&client, &config, TEST_ADDRESS).await;
match result {
Err(Error::HttpRequestErr(code, message)) => {
Err(Error::HttpRequest(code, message)) => {
assert_eq!(code, StatusCode::BAD_REQUEST);
assert!(message.contains("Bad request - Invalid parameters or data"));
}
_ => panic!("Expected HttpRequestErr, got {:?}", result),
_ => panic!("Expected HttpRequest, got {:?}", result),
}
}

Expand Down Expand Up @@ -249,28 +247,6 @@ mod tests {
}
}

#[tokio::test]
async fn test_get_risk_assessment_invalid_risk_value() {
let _m = setup_mock(
"GET",
format!("{}/{}", API_BASE_PATH, TEST_ADDRESS).as_str(),
200,
r#"{"risk": "mild"}"#,
);
let (client, config) = setup_client();

let result = get_risk_assessment(&client, &config, TEST_ADDRESS).await;
match result {
Ok(_) => panic!("Test failed: Expected an Error::InvalidRiskValue, but got Ok"),
Err(e) => match e {
Error::InvalidRiskValue(_) => {
assert!(true, "Received the expected Error::InvalidRiskValue");
}
_ => panic!("Test failed: Expected Error::InvalidRiskValue, got {e:?}"),
},
}
}

#[tokio::test]
async fn test_check_address_blocklisted_for_high_risk() {
let _reg_mock = setup_mock("POST", API_BASE_PATH, 200, ADDRESS_REGISTRATION_BODY);
Expand All @@ -286,8 +262,8 @@ mod tests {
assert!(result.is_ok());
let status = result.unwrap();
assert!(status.is_blocklisted);
assert_eq!(status.severity, Severe.to_string());
assert_eq!(status.reason, "fraud");
assert_eq!(status.severity, Severe);
assert_eq!(status.reason, Some("fraud".to_string()));
assert!(!status.accept);
}

Expand All @@ -306,8 +282,8 @@ mod tests {
assert!(result.is_ok());
let status = result.unwrap();
assert!(!status.is_blocklisted);
assert_eq!(status.severity, Low.to_string());
assert!(status.reason.is_empty());
assert_eq!(status.severity, Low);
assert!(status.reason.is_none());
assert!(status.accept);
}

Expand All @@ -324,8 +300,8 @@ mod tests {
let result = check_address(&client, &config, TEST_ADDRESS).await;
assert!(result.is_err());
match result {
Err(Error::HttpRequestErr(code, _)) => assert_eq!(code, StatusCode::BAD_REQUEST),
_ => panic!("Expected HttpRequestErr for bad registration"),
Err(Error::HttpRequest(code, _)) => assert_eq!(code, StatusCode::BAD_REQUEST),
_ => panic!("Expected HttpRequest for bad registration"),
}
}

Expand All @@ -343,10 +319,10 @@ mod tests {
let result = check_address(&client, &config, TEST_ADDRESS).await;
assert!(result.is_err());
match result {
Err(Error::InternalServerErr) => {
Err(Error::InternalServer) => {
assert!(true, "Received expected internal server error")
}
_ => panic!("Expected InternalServerErr for failed risk assessment"),
_ => panic!("Expected InternalServer for failed risk assessment"),
}
}
}
23 changes: 8 additions & 15 deletions blocklist-client/src/common/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,17 @@ use warp::reject::Reject;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("HTTP request failed with status code {0}: {1}")]
HttpRequestErr(StatusCode, String),
HttpRequest(StatusCode, String),

#[error("Network error: {0}")]
NetworkErr(#[from] reqwest::Error),
Network(#[from] reqwest::Error),

#[error("Serialization error: {0}")]
SerializationErr(#[from] serde_json::Error),
Serialization(String),

#[error("Invalid API response structure")]
InvalidApiResponse,

#[error("Invalid risk value provided: {0}")]
InvalidRiskValue(String),

#[error("Unauthorized access - check your API key")]
Unauthorized,

Expand All @@ -32,7 +29,7 @@ pub enum Error {
Conflict,

#[error("Internal server error")]
InternalServerErr,
InternalServer,

#[error("Service unavailable")]
ServiceUnavailable,
Expand All @@ -44,9 +41,9 @@ pub enum Error {
impl Error {
pub fn as_http_response(&self) -> (StatusCode, String) {
match self {
Error::HttpRequestErr(code, msg) => (*code, msg.clone()),
Error::NetworkErr(_) => (StatusCode::BAD_GATEWAY, "Network error".to_string()),
Error::SerializationErr(_) => (
Error::HttpRequest(code, msg) => (*code, msg.clone()),
Error::Network(_) => (StatusCode::BAD_GATEWAY, "Network error".to_string()),
Error::Serialization(_) => (
StatusCode::BAD_REQUEST,
"Error in processing the data".to_string(),
),
Expand All @@ -64,7 +61,7 @@ impl Error {
"Not acceptable format requested".to_string(),
),
Error::Conflict => (StatusCode::CONFLICT, "Request conflict".to_string()),
Error::InternalServerErr => (
Error::InternalServer => (
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error".to_string(),
),
Expand All @@ -73,10 +70,6 @@ impl Error {
"Service unavailable".to_string(),
),
Error::RequestTimeout => (StatusCode::REQUEST_TIMEOUT, "Request timeout".to_string()),
Error::InvalidRiskValue(_) => (
StatusCode::BAD_REQUEST,
"Invalid API response risk value".to_string(),
),
}
}
}
Expand Down
14 changes: 10 additions & 4 deletions blocklist-client/src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@ use std::fmt;
pub mod error;

/// The BlocklistStatus of a user address
#[derive(Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BlocklistStatus {
// Whether the address is blocklisted or not
pub is_blocklisted: bool,
// The risk severity associated with an address
pub severity: String,
pub severity: RiskSeverity,
// Blocklist client's acceptance decision based on the risk severity of the address
pub accept: bool,
// Reason for the acceptance decision
pub reason: String,
pub reason: Option<String>,
}

/// Risk severity linked to an address
#[derive(Debug, PartialEq, Eq, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RiskSeverity {
Low,
Medium,
Expand All @@ -35,3 +35,9 @@ impl fmt::Display for RiskSeverity {
}
}
}

impl RiskSeverity {
pub fn is_severe(&self) -> bool {
matches!(self, RiskSeverity::Severe)
}
}

0 comments on commit 42db0b7

Please sign in to comment.