How to do clustering grid search with multiple CPUs / GPUs? #337

Open
jamesaphoenix opened this issue Mar 27, 2024 · 2 comments

Comments

@jamesaphoenix

I'm currently building a Wasm project that will expose some clustering functionality to the browser.

Questions:

  • Does linfa have grid search functionality, or should I simply loop over multiple model.fit calls sequentially?
  • What's the easiest way to implement this for three clustering techniques?

I'm looking to use all of these:
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/dbscan.rs
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/kmeans.rs
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/optics.rs

  • Also, do I need to implement multi-core processing myself, similar to joblib in Python, or is this handled by linfa?
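A minimal sketch of the sequential option, assuming no built-in grid-search helper for clustering (I haven't confirmed whether linfa has one): list every candidate configuration for the three techniques up front, then fit and score each one in a plain loop. `ClusterParams`, `build_grid` and `grid_search` are hypothetical names, the parameter values are placeholders, and the `fit_and_score` closure is where the real KMeans, DBSCAN and OPTICS fit calls plus a scoring metric would go.

```rust
/// One candidate configuration in the search grid (hypothetical type).
#[derive(Debug, Clone, Copy, PartialEq)]
enum ClusterParams {
    KMeans { n_clusters: usize },
    Dbscan { min_points: usize, tolerance: f64 },
    Optics { min_points: usize, tolerance: f64 },
}

/// Builds the full grid: one KMeans candidate per k, and every
/// (min_points, tolerance) pair for both DBSCAN and OPTICS.
fn build_grid(ks: &[usize], min_points_values: &[usize], tolerances: &[f64]) -> Vec<ClusterParams> {
    let mut grid: Vec<ClusterParams> = ks
        .iter()
        .map(|&n_clusters| ClusterParams::KMeans { n_clusters })
        .collect();
    // DBSCAN and OPTICS share the same (min_points, tolerance) grid.
    for &min_points in min_points_values {
        for &tolerance in tolerances {
            grid.push(ClusterParams::Dbscan { min_points, tolerance });
            grid.push(ClusterParams::Optics { min_points, tolerance });
        }
    }
    grid
}

/// Fits and scores every candidate sequentially and returns the one with the
/// highest score. Candidates whose fit returns an error are skipped.
fn grid_search<E>(
    grid: &[ClusterParams],
    mut fit_and_score: impl FnMut(ClusterParams) -> Result<f64, E>,
) -> Option<(ClusterParams, f64)> {
    let mut best: Option<(ClusterParams, f64)> = None;
    for &params in grid {
        if let Ok(score) = fit_and_score(params) {
            // Keep the first candidate seen, then only strictly better ones.
            if best.map_or(true, |(_, best_score)| score > best_score) {
                best = Some((params, score));
            }
        }
    }
    best
}
```

The dataset would be built once outside the closure, and each match arm would run the corresponding fit, for example the `KMeans::params_with(...).fit(&dataset)` call from `cluster_embeddings_with_kmeans`. Returning `Result` lets one failed fit be skipped instead of aborting the whole search.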

Thanks in advance, and great package btw!
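On the joblib question: a sketch of scoring candidates on several CPU cores with `std::thread::scope`, splitting the list into one chunk per worker and joining the results back in input order. `parallel_scores` is a hypothetical helper, and this doesn't assume anything about whether linfa parallelizes internally. One caveat: std threads aren't available on the default wasm32-unknown-unknown target, so in the browser this would need a threads-enabled build (for example via the wasm-bindgen-rayon crate) or Web Workers; on native targets it runs as is, and rayon's `par_iter` would be the more common choice there.

```rust
use std::thread;

/// Scores every candidate using up to `n_workers` scoped threads and returns
/// the scores in the same order as `candidates` (hypothetical helper).
fn parallel_scores<P, F>(candidates: &[P], n_workers: usize, fit_and_score: F) -> Vec<f64>
where
    P: Sync,
    F: Fn(&P) -> f64 + Sync,
{
    if candidates.is_empty() {
        return Vec::new();
    }
    let n_workers = n_workers.clamp(1, candidates.len());
    // Ceiling division so every candidate lands in some chunk.
    let chunk_size = (candidates.len() + n_workers - 1) / n_workers;
    let f = &fit_and_score;
    thread::scope(|s| {
        // Spawn all workers first (collect), so they run concurrently.
        let handles: Vec<_> = candidates
            .chunks(chunk_size)
            .map(|chunk| s.spawn(move || chunk.iter().map(f).collect::<Vec<f64>>()))
            .collect();
        // Join in spawn order so the output order matches the input order.
        handles
            .into_iter()
            .flat_map(|h| h.join().expect("worker thread panicked"))
            .collect()
    })
}
```

Because the handles are joined in spawn order, `scores[i]` lines up with `candidates[i]`, so the best configuration can be picked exactly as in a sequential loop.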

@jamesaphoenix
Author

Here is my current library, to provide some context:

use linfa::traits::Fit;
use linfa::traits::Predict;
use linfa::DatasetBase;
use linfa_clustering::KMeans;
use linfa_nn::distance::LInfDist;
use ndarray::Array2;
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;

// Data types:
#[derive(Serialize, Deserialize)]
struct Embedding {
    keyword: String,
    embeddings: Vec<f64>,
}

#[derive(Serialize, Deserialize)]
struct EnrichedEmbedding {
    embedding: Embedding,
    cluster: usize,
    is_main_keyword_in_cluster: bool,
}

#[wasm_bindgen]
extern "C" {
    #[wasm_bindgen(js_namespace = console)]
    fn log(s: &str);
}

#[wasm_bindgen]
pub fn greet(name: &str) -> String {
    format!("Hello, {}!", name)
}

// Clusters JSON-encoded embeddings with KMeans; errors on empty input or more than 100,000 embeddings.

#[wasm_bindgen]
pub fn cluster_embeddings_with_kmeans(
    json_embeddings: &str,
    n_clusters: usize,
) -> Result<String, JsValue> {
    let rng = Xoshiro256Plus::seed_from_u64(42);

    // Deserialize JSON embeddings:
    let embeddings: Vec<Embedding> =
        serde_json::from_str(json_embeddings).map_err(|e| JsValue::from_str(&e.to_string()))?;

    // println! output is not visible in the browser, so log to the JS console instead
    log(&format!("Number of embeddings: {}", embeddings.len()));

    // If there are more than 100,000 embeddings:
    if embeddings.len() > 100000 {
        return Err(JsValue::from_str(
            "The number of embeddings is too large. Please use a smaller dataset.",
        ));
    }

    if embeddings.is_empty() {
        return Err(JsValue::from_str(
            "The number of embeddings is 0. Please provide some embeddings.",
        ));
    }

    // Convert embeddings to ndarray
    let rows = embeddings.len();
    let cols = embeddings[0].embeddings.len();
    // Flatten row by row without cloning each embedding vector
    let flattened: Vec<f64> = embeddings
        .iter()
        .flat_map(|e| e.embeddings.iter().copied())
        .collect();
    let array = Array2::from_shape_vec((rows, cols), flattened)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let dataset = DatasetBase::from(array);

    log("Clustering embeddings in Rust...");

    // Cluster embeddings in Rust:
    let model = KMeans::params_with(n_clusters, rng, LInfDist)
        .max_n_iterations(1000)
        .fit(&dataset)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    log("Finished clustering embeddings in Rust");
    log("Assigning points to clusters...");

    // Assign each point to a cluster using the set of centroids found using `fit`
    let dataset = model.predict(dataset);
    let DatasetBase { targets, .. } = dataset;

    // Pair each original embedding with its cluster assignment
    let enriched_embeddings: Vec<EnrichedEmbedding> = embeddings
        .into_iter()
        .zip(targets.iter())
        .map(|(embedding, &cluster)| EnrichedEmbedding {
            embedding,
            cluster, // KMeans predictions are already `usize` cluster ids
            is_main_keyword_in_cluster: false, // Placeholder logic here
        })
        .collect();

    // Serialize the enriched embeddings
    serde_json::to_string(&enriched_embeddings).map_err(|e| JsValue::from_str(&e.to_string()))
}

#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    #[test]
    fn testing_greeting() {
        assert_eq!(greet("world"), "Hello, world!");
    }

    #[wasm_bindgen_test]
    fn test_cluster_embeddings() {
        let mock_json = r#"
            [
                {
                    "keyword": "rust",
                    "embeddings": [0.1, 0.2, 0.3]
                },
                {
                    "keyword": "wasm",
                    "embeddings": [0.4, 0.5, 0.6]
                }
            ]
        "#;

        let n_clusters = 2; // For simplicity, choose a small number of clusters

        // Call the function with the mocked JSON and the number of clusters
        let result = cluster_embeddings_with_kmeans(mock_json, n_clusters);

        // Check that the function succeeded
        assert!(result.is_ok());

        // Deserialize the result to verify its structure
        let enriched_embeddings: Vec<EnrichedEmbedding> =
            serde_json::from_str(&result.unwrap()).unwrap();

        // Verify that each embedding has been assigned a cluster
        assert_eq!(enriched_embeddings.len(), 2);
        for enriched_embedding in enriched_embeddings {
            assert!(enriched_embedding.cluster < n_clusters);
        }
    }

    #[wasm_bindgen_test]
    fn test_cluster_embeddings_with_no_embeddings() {
        let mock_json = r#"
            []
        "#;

        let n_clusters = 2; // For simplicity, choose a small number of clusters
        let result = cluster_embeddings_with_kmeans(mock_json, n_clusters);
        assert!(result.is_err())
    }

    #[wasm_bindgen_test]
    fn test_cluster_embeddings_with_large_dataset() {
        // Build a valid JSON array of 100,001 embeddings to exceed the 100,000 limit:
        let single_embedding = r#"{"keyword": "rust", "embeddings": [0.1, 0.2, 0.3]}"#;
        let mut mock_json_new = String::from("[");
        for i in 0..100_001 {
            if i > 0 {
                mock_json_new.push(',');
            }
            mock_json_new.push_str(single_embedding);
        }
        mock_json_new.push(']');
        let n_clusters = 2; // For simplicity, choose a small number of clusters

        // Pass the oversized array (not the single fragment) to the function
        let result = cluster_embeddings_with_kmeans(&mock_json_new, n_clusters);

        // Check that the function rejected the dataset:
        assert!(result.is_err());
    }

    // Top-level test (it was previously nested inside the test above, so it never ran)
    #[wasm_bindgen_test]
    fn test_cluster_embeddings_with_3k_embeddings() {
        let mut mock_json_new = String::from("[");
        let single_embedding = r#"{"keyword": "rust", "embeddings": [0.1, 0.2, 0.3]}"#;
        for i in 0..3000 {
            if i > 0 {
                mock_json_new.push(',');
            }
            mock_json_new.push_str(single_embedding);
        }
        mock_json_new.push(']');

        let n_clusters = 2; // For simplicity, choose a small number of clusters

        // Call the function with the mocked JSON and the number of clusters
        let result = cluster_embeddings_with_kmeans(&mock_json_new, n_clusters);
        assert!(result.is_ok());
    }
}
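To choose between grid candidates, each fit needs a score that stays comparable across k and across techniques; KMeans inertia alone tends to keep falling as k grows. One common option, swapped in here and not something the library above computes, is the mean silhouette coefficient. This is a std-only sketch assuming Euclidean distance (the library clusters with LInfDist, which could be substituted) and plain `usize` labels; DBSCAN/OPTICS noise points would have to be filtered out first. `silhouette_score` is a hypothetical helper, and it costs O(n²) distance evaluations, fine for a few thousand points but slow near the 100,000 limit.

```rust
/// Euclidean distance between two points of equal dimension.
fn euclidean(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f64>()
        .sqrt()
}

/// Mean silhouette coefficient of a clustering (hypothetical helper).
/// `labels[i]` is the cluster id of `points[i]`. Returns `None` unless there
/// are between 2 and n - 1 non-empty clusters, where the score is defined.
fn silhouette_score(points: &[Vec<f64>], labels: &[usize]) -> Option<f64> {
    assert_eq!(points.len(), labels.len(), "expected one label per point");
    let n = points.len();
    // Count members per cluster id (ids need not be contiguous).
    let n_ids = labels.iter().max().map_or(0, |&m| m + 1);
    let mut sizes = vec![0usize; n_ids];
    for &label in labels {
        sizes[label] += 1;
    }
    let n_clusters = sizes.iter().filter(|&&s| s > 0).count();
    if n_clusters < 2 || n_clusters >= n {
        return None;
    }

    let mut total = 0.0;
    for i in 0..n {
        let own = labels[i];
        // A point alone in its cluster has a silhouette of 0 by convention.
        if sizes[own] == 1 {
            continue;
        }
        // Sum of distances from point i to the members of each cluster.
        let mut sums = vec![0.0; n_ids];
        for j in 0..n {
            if j != i {
                sums[labels[j]] += euclidean(&points[i], &points[j]);
            }
        }
        // a: mean distance to the other members of its own cluster.
        let a = sums[own] / (sizes[own] - 1) as f64;
        // b: smallest mean distance to any other non-empty cluster.
        let b = (0..n_ids)
            .filter(|&c| c != own && sizes[c] > 0)
            .map(|c| sums[c] / sizes[c] as f64)
            .fold(f64::INFINITY, f64::min);
        let denom = a.max(b);
        if denom > 0.0 {
            total += (b - a) / denom;
        }
    }
    Some(total / n as f64)
}
```

Scores range from -1 to 1, and higher means tighter, better-separated clusters, so the value can be returned directly as the score for each candidate.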

@jamesaphoenix
Author

Bump on this?
