
mimiquate/candle

candle

Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use. Try our online demos: whisper, LLaMA2, T5, yolo, Segment Anything.

Get started

Make sure that you have candle-core correctly installed as described in Installation.

Let's see how to run a simple matrix multiplication. Write the following to your myapp/src/main.rs file:

use candle_core::{Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;

    let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &device)?;

    let c = a.matmul(&b)?;
    println!("{c}");
    Ok(())
}

cargo run should display a tensor of shape Tensor[[2, 4], f32].
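The (2, 4) result follows the usual shape rule: an (m, k) matrix times a (k, n) matrix needs matching inner dimensions and yields an (m, n) matrix. The std-only sketch below spells that rule out. It is an illustration, not candle's implementation; the matmul helper and its flat row-major layout are hypothetical. Like Tensor::matmul, it returns a Result, so a shape mismatch becomes an error that the ? operator can propagate.

```rust
/// Hypothetical helper (not part of candle): multiplies an (m, k) matrix by a
/// (k, n) matrix, both stored as flat row-major slices, and returns the
/// (m, n) result. Shape problems are reported as errors instead of panics.
fn matmul(
    a: &[f32],
    (m, k): (usize, usize),
    b: &[f32],
    (k2, n): (usize, usize),
) -> Result<Vec<f32>, String> {
    if k != k2 {
        return Err(format!("shape mismatch in matmul: ({m}, {k}) x ({k2}, {n})"));
    }
    if a.len() != m * k || b.len() != k * n {
        return Err("buffer length does not match the declared shape".to_string());
    }
    let mut c = vec![0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let a_ip = a[i * k + p];
            for j in 0..n {
                // c[i][j] accumulates a[i][p] * b[p][j] over the shared dimension p.
                c[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
    Ok(c)
}

fn main() -> Result<(), String> {
    // Same shapes as the candle example: (2, 3) x (3, 4) -> (2, 4).
    let a = vec![1.0f32; 2 * 3];
    let b = vec![1.0f32; 3 * 4];
    let c = matmul(&a, (2, 3), &b, (3, 4))?;
    println!("{} elements: {c:?}", c.len()); // 8 elements, each equal to 3.0
    // Swapping the operands breaks the shape rule: (3, 4) x (2, 3) is an error.
    assert!(matmul(&b, (3, 4), &a, (2, 3)).is_err());
    Ok(())
}
```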

Once candle is installed with CUDA support, simply define the device to be on the GPU:

- let device = Device::Cpu;
+ let device = Device::new_cuda(0)?;

For more advanced examples, please have a look at the following section.

Check out our examples

These online demos run entirely in your browser: whisper, LLaMA2, T5, yolo, and Segment Anything.

We also provide some command-line examples using state-of-the-art models:

  • LLaMA and LLaMA-v2: general LLM, includes the SOLAR-10.7B variant.
  • Falcon: general LLM.
  • Phi-1, Phi-1.5, and Phi-2: 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
  • StableLM-3B-4E1T: a 3b general LLM pre-trained on 1T tokens of English and code datasets.
  • Minimal Mamba: a minimal implementation of the Mamba state space model.
  • Mistral7b-v0.1: a 7b general LLM with better performance than all publicly available 13b models as of 2023-09-28.
  • Mixtral8x7b-v0.1: a sparse mixture of experts 8x7b general LLM with better performance than a Llama 2 70B model with much faster inference.
  • StarCoder: LLM specialized in code generation.
  • Replit-code-v1.5: a 3.3b LLM specialized for code completion.
  • Yi-6B / Yi-34B: two bilingual (English/Chinese) general LLMs with 6b and 34b parameters.
  • Quantized LLaMA: quantized version of the LLaMA model using the same quantization techniques as llama.cpp.
  • Stable Diffusion: text-to-image generative model, with support for the 1.5, 2.1, SDXL 1.0, and Turbo versions.
  • Wuerstchen: another text-to-image generative model.
  • Whisper: speech recognition model.
  • T5, Bert, JinaBert: useful for sentence embeddings.
  • DINOv2: computer vision model trained using self-supervision (can be used for ImageNet classification, depth estimation, and segmentation).
  • BLIP: image-to-text model that can generate captions for an image.
  • Marian-MT: neural machine translation model that translates the input text.

Run them using commands like:

cargo run --example quantized --release

To use CUDA, add --features cuda to the example command line. If you have cuDNN installed, use --features cudnn for even greater speedups.

There are also some wasm examples for whisper and llama2.c. You can either build them with trunk or try them online: whisper, llama2, T5, Phi-1.5 and Phi-2, and Segment Anything Model.

For LLaMA2, run the following commands to retrieve the weight files and start a test server:

cd candle-wasm-examples/llama2-c
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --port 8081

And then head over to http://localhost:8081/.

Useful External Resources

  • candle-tutorial: A very detailed tutorial showing how to convert a PyTorch model to Candle.
  • candle-lora: Efficient and ergonomic LoRA implementation for Candle. candle-lora has out-of-the-box LoRA support for many of Candle's models.
  • optimisers: A collection of optimisers including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
  • candle-vllm: An efficient platform for inference and serving of local LLMs, including an OpenAI-compatible API server.
  • candle-ext: An extension library to Candle that provides PyTorch functions not currently available in Candle.
  • kalosm: A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
  • candle-sampling: Sampling techniques for Candle.

If you have an addition to this list, please submit a pull request.

Features

  • Simple syntax, looks and feels like PyTorch.
  • Backends.
    • Optimized CPU backend with optional MKL support for x86 and Accelerate for Macs.
    • CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
    • WASM support, run your models in a browser.
  • Included models.
    • Language Models.
      • LLaMA v1 and v2 with variants such as SOLAR-10.7B.
      • Falcon.
      • StarCoder.
      • Phi 1, 1.5, and 2.
      • Minimal Mamba.
      • Mistral 7b v0.1.
      • Mixtral 8x7b v0.1.
      • StableLM-3B-4E1T.
      • Replit-code-v1.5-3B.
      • Bert.
      • Yi-6B and Yi-34B.
    • Quantized LLMs.
      • Llama 7b, 13b, 70b, as well as the chat and code variants.
      • Mistral 7b, and 7b instruct.
      • Mixtral 8x7b.
      • Zephyr 7b a and b (Mistral-7b based).
      • OpenChat 3.5 (Mistral-7b based).
    • Text to text.
      • T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (grammar correction).
      • Marian MT (Machine Translation).
    • Whisper (multi-lingual support).
    • Text to image.
      • Stable Diffusion v1.5, v2.1, XL v1.0.
      • Wuerstchen v2.
    • Image to text.
      • BLIP.
    • Computer Vision Models.
      • DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
      • yolo-v3, yolo-v8.
      • Segment-Anything Model (SAM).
  • File formats: load models from safetensors, npz, ggml, or PyTorch files.
  • Serverless (on CPU), small and fast deployments.
  • Quantization support using the llama.cpp quantized types.
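To give a feel for what a llama.cpp quantized type stores, the std-only sketch below implements a Q8_0-style scheme: values are grouped in blocks of 32, each block keeps one scale d = max|x| / 127, and each value is stored as the signed 8-bit integer round(x / d). This is an illustration, not candle's code; the Q8Block type and the quantize/dequantize helpers are hypothetical, and llama.cpp stores the scale as an f16 whereas this sketch keeps an f32 to stay std-only.

```rust
/// Number of values per block in llama.cpp's Q8_0 format.
const QK8_0: usize = 32;

/// One Q8_0-style block: a per-block scale and 32 signed 8-bit values.
/// (llama.cpp stores `d` as f16; f32 is used here to stay std-only.)
#[derive(Debug, Clone)]
struct Q8Block {
    d: f32,
    qs: [i8; QK8_0],
}

/// Quantizes `xs` block by block; the length must be a multiple of 32.
fn quantize_q8_0(xs: &[f32]) -> Vec<Q8Block> {
    assert!(xs.len() % QK8_0 == 0, "length must be a multiple of {QK8_0}");
    xs.chunks_exact(QK8_0)
        .map(|block| {
            // The largest magnitude in the block maps to +/-127.
            let amax = block.iter().fold(0f32, |m, x| m.max(x.abs()));
            let d = amax / 127.0;
            let id = if d != 0.0 { 1.0 / d } else { 0.0 };
            let mut qs = [0i8; QK8_0];
            for (q, &x) in qs.iter_mut().zip(block) {
                *q = (x * id).round() as i8;
            }
            Q8Block { d, qs }
        })
        .collect()
}

/// Reconstructs approximate f32 values as q * d.
fn dequantize_q8_0(blocks: &[Q8Block]) -> Vec<f32> {
    blocks
        .iter()
        .flat_map(|b| b.qs.iter().map(move |&q| q as f32 * b.d))
        .collect()
}

fn main() {
    // 64 sample values -> two blocks of 32.
    let xs: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
    let blocks = quantize_q8_0(&xs);
    let ys = dequantize_q8_0(&blocks);
    let max_err = xs.iter().zip(&ys).map(|(x, y)| (x - y).abs()).fold(0f32, f32::max);
    // Each block now holds 32 one-byte values plus a single scale.
    println!("{} blocks, max abs error {max_err:.5}", blocks.len());
}
```

The per-value error is bounded by half a quantization step (d / 2), which is why a block's largest magnitude determines its precision.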

How to use

Cheatsheet:

            Using PyTorch                        Using Candle
Creation    torch.Tensor([[1, 2], [3, 4]])       Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?
Creation    torch.zeros((2, 2))                  Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?
Indexing    tensor[:, :4]                        tensor.i((.., ..4))?
Operations  tensor.view((2, 2))                  tensor.reshape((2, 2))?
Operations  a.matmul(b)                          a.matmul(&b)?
Arithmetic  a + b                                (&a + &b)?
Device      tensor.to(device="cuda")             tensor.to_device(&Device::new_cuda(0)?)?
Dtype       tensor.to(dtype=torch.float16)       tensor.to_dtype(DType::F16)?
Saving      torch.save({"A": A}, "model.bin")    candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?
Loading     weights = torch.load("model.bin")    candle::safetensors::load("model.safetensors", &device)?
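To spell out what the reshape and indexing rows of the cheatsheet do to the underlying data, here is a std-only sketch on a row-major 2-D buffer. The Mat type and its reshape and first_cols methods are hypothetical; they only mimic the semantics of tensor.reshape((r, c)) and tensor.i((.., ..n)) and are not candle's Tensor. The sketch copies data, whereas a tensor library can often avoid the copy by adjusting strides.

```rust
/// Hypothetical row-major 2-D buffer; it only mimics the semantics of
/// tensor.reshape((r, c)) and tensor.i((.., ..n)) and is not candle's Tensor.
#[derive(Debug, Clone, PartialEq)]
struct Mat {
    data: Vec<f32>,
    rows: usize,
    cols: usize,
}

impl Mat {
    /// Reshape keeps the elements in the same row-major order and only changes
    /// the shape, so the total element count must stay the same.
    fn reshape(&self, rows: usize, cols: usize) -> Result<Mat, String> {
        if rows * cols != self.data.len() {
            return Err(format!(
                "cannot reshape {} elements into ({rows}, {cols})",
                self.data.len()
            ));
        }
        Ok(Mat { data: self.data.clone(), rows, cols })
    }

    /// Equivalent of tensor[:, :n]: keep every row, but only its first n columns.
    fn first_cols(&self, n: usize) -> Result<Mat, String> {
        if n > self.cols {
            return Err(format!("cannot take {n} columns out of {}", self.cols));
        }
        let (src, cols) = (&self.data, self.cols);
        let data = (0..self.rows)
            .flat_map(move |r| src[r * cols..r * cols + n].iter().copied())
            .collect();
        Ok(Mat { data, rows: self.rows, cols: n })
    }
}

fn main() -> Result<(), String> {
    // 2 x 4 matrix holding 0.0 ..= 7.0 in row-major order.
    let m = Mat { data: (0..8).map(|x| x as f32).collect(), rows: 2, cols: 4 };
    let r = m.reshape(4, 2)?; // same 8 values, now 4 rows of 2
    let s = m.first_cols(2)?; // rows [0, 1] and [4, 5]
    println!("{r:?}\n{s:?}");
    Ok(())
}
```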

Structure

FAQ

Why should I use Candle?

Candle's core goal is to make serverless inference possible. Full machine learning frameworks like PyTorch are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight binaries.

Secondly, Candle lets you remove Python from production workloads. Python overhead can seriously hurt performance, and the GIL is a notorious source of headaches.

Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like safetensors and tokenizers.

Other ML frameworks

  • dfdx is a formidable crate, with shapes included in types. This prevents a lot of headaches by getting the compiler to complain about shape mismatches right off the bat. However, we found that some features still require nightly Rust, and writing code can be a bit daunting for non-Rust experts.

    We're leveraging and contributing to other core crates for the runtime, so hopefully both crates can benefit from each other.

  • burn is a general crate that can leverage multiple backends so you can choose the best engine for your workload.

  • tch-rs: bindings to the torch library in Rust. Extremely versatile, but they bring the entire torch library into the runtime. The main contributor of tch-rs is also involved in the development of candle.
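The "shapes in types" idea mentioned for dfdx can be illustrated with Rust's const generics. The sketch below is hypothetical and is not dfdx's API: the Matrix type carries its dimensions as const parameters, so matmul only type-checks when the inner dimensions agree, and the compiler works out the output shape. Candle, by contrast, checks shapes at runtime and reports mismatches through its Result-returning operations.

```rust
/// Hypothetical matrix type whose shape is part of its type via const generics.
/// It only illustrates the "shapes in types" idea; it is not dfdx's API.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Matrix<const R: usize, const C: usize>([[f32; C]; R]);

impl<const R: usize, const K: usize> Matrix<R, K> {
    /// rhs must have exactly K rows: a mismatch is a compile-time type error.
    fn matmul<const C: usize>(&self, rhs: &Matrix<K, C>) -> Matrix<R, C> {
        let mut out = [[0f32; C]; R];
        for i in 0..R {
            for j in 0..C {
                for p in 0..K {
                    out[i][j] += self.0[i][p] * rhs.0[p][j];
                }
            }
        }
        Matrix(out)
    }
}

fn main() {
    let a = Matrix::<2, 3>([[1.0; 3]; 2]);
    let b = Matrix::<3, 4>([[1.0; 4]; 3]);
    // The output type Matrix<2, 4> is determined by the input types.
    let c: Matrix<2, 4> = a.matmul(&b);
    println!("{c:?}");
    // let bad = a.matmul(&a); // does not compile: inner dimensions 3 and 2 differ
}
```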

Common Errors

Missing symbols when compiling with the mkl feature.

If you get missing symbols when compiling binaries or tests with the mkl or accelerate features, for example this error for mkl:

  = note: /usr/bin/ld: (....o): in function `blas::sgemm':
          .../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_' collect2: error: ld returned 1 exit status

  = note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
  = note: use the `-l` flag to specify native libraries to link
  = note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo

or for accelerate:

Undefined symbols for architecture arm64:
            "_dgemm_", referenced from:
                candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
            "_sgemm_", referenced from:
                candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
          ld: symbol(s) not found for architecture arm64

This is likely caused by a missing linker flag needed to enable the mkl library. For mkl, you can try adding the following at the top of your binary:

extern crate intel_mkl_src;

or for accelerate:

extern crate accelerate_src;
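If the same binary should build with or without these backends, one option is to gate the extern crate lines on Cargo features. The sketch below assumes your crate defines features named mkl and accelerate that enable the intel-mkl-src and accelerate-src dependencies respectively; those feature names and the blas_backend helper are illustrative assumptions. With neither feature enabled, both extern crate lines are compiled out.

```rust
// Link the BLAS provider only when the matching (assumed) Cargo feature is enabled.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

/// Hypothetical helper reporting which BLAS provider this binary was built with.
fn blas_backend() -> &'static str {
    if cfg!(feature = "mkl") {
        "mkl"
    } else if cfg!(feature = "accelerate") {
        "accelerate"
    } else {
        "none"
    }
}

fn main() {
    println!("BLAS backend: {}", blas_backend());
}
```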

Cannot run the LLaMA examples: access to source requires login credentials

Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401

This is likely because you don't have access to the LLaMA-v2 model. To fix this, register on the Hugging Face Hub, accept the LLaMA-v2 model conditions, and set up your authentication token. See issue #350 for more details.

Missing cute/cutlass headers when compiling flash-attn

  In file included from kernels/flash_fwd_launch_template.h:11:0,
                   from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
  kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
   #include <cute/algorithm/copy.hpp>
            ^~~~~~~~~~~~~~~~~~~~~~~~~
  compilation terminated.
  Error: nvcc error while compiling:

cutlass is provided as a git submodule, so you may want to run the following command to check it out properly.

git submodule update --init

Compiling with flash-attention fails

/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:

This is a bug in gcc-11 triggered by the CUDA compiler. To fix this, install a different, supported gcc version, for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.

env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
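nvcc's -ccbin option selects the host C++ compiler it delegates to, which is why pointing it at gcc-10 sidesteps the gcc-11 bug. The sketch below shows how a build script might forward CANDLE_NVCC_CCBIN to nvcc. It is a hypothetical illustration, not candle's actual build script: the nvcc_args helper and the kernel path are made up, and the program only prints the command instead of running it.

```rust
use std::env;

/// Hypothetical helper: builds the nvcc argument list for one kernel file and,
/// when a host-compiler path is given, forwards it through nvcc's -ccbin flag.
fn nvcc_args(src: &str, ccbin: Option<&str>) -> Vec<String> {
    let mut args = vec!["-c".to_string(), src.to_string()];
    if let Some(dir) = ccbin {
        args.push("-ccbin".to_string());
        args.push(dir.to_string());
    }
    args
}

fn main() {
    // In a real build.rs, this asks Cargo to rerun the script when the variable changes.
    println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
    let ccbin = env::var("CANDLE_NVCC_CCBIN").ok();
    let args = nvcc_args("kernels/example.cu", ccbin.as_deref());
    println!("nvcc {}", args.join(" "));
}
```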

Linking error on Windows when running rustdoc or mdbook tests

Couldn't compile the test.
---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
error: linking with `link.exe` failed: exit code: 1181
//very long chain of linking
 = note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'

Make sure you link all native libraries that might be located outside the project target directory. For example, to run mdbook tests, you should run:

mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib

Extremely slow model load time with WSL

This may be caused by loading the models from /mnt/c; see the related Stack Overflow discussion for more details.

Tracking down errors

Set RUST_BACKTRACE=1 to get a backtrace whenever candle generates an error.
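Rust's standard library exposes the same switch through std::backtrace::Backtrace::capture, which only records frames when RUST_BACKTRACE (or RUST_LIB_BACKTRACE) enables it. The sketch below shows that mechanism on a hypothetical TracedError type and check_dims function; it is not candle's error type, only an illustration of how an error can carry an optional backtrace.

```rust
use std::backtrace::{Backtrace, BacktraceStatus};
use std::fmt;

/// Hypothetical error type (not candle's) that records where it was created.
#[derive(Debug)]
struct TracedError {
    msg: String,
    backtrace: Backtrace,
}

impl TracedError {
    fn new(msg: impl Into<String>) -> Self {
        // Records frames only when RUST_BACKTRACE / RUST_LIB_BACKTRACE enable it;
        // otherwise the backtrace reports a Disabled status and holds nothing.
        Self { msg: msg.into(), backtrace: Backtrace::capture() }
    }
}

impl fmt::Display for TracedError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.msg)?;
        if self.backtrace.status() == BacktraceStatus::Captured {
            write!(f, "\n{}", self.backtrace)?;
        }
        Ok(())
    }
}

impl std::error::Error for TracedError {}

/// Hypothetical check that fails when two dimensions differ.
fn check_dims(lhs: usize, rhs: usize) -> Result<(), TracedError> {
    if lhs != rhs {
        return Err(TracedError::new(format!("dimension mismatch: {lhs} vs {rhs}")));
    }
    Ok(())
}

fn main() {
    // Run with RUST_BACKTRACE=1 to see the captured frames after the message.
    if let Err(e) = check_dims(3, 2) {
        eprintln!("error: {e}");
    }
}
```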
