
jerry73204/mobilenet-v3-rs


MobileNetV3 in Rust

This project implements Google's MobileNetV3 with tch-rs, a Rust binding for libtorch. The MobileNetV3 model design was proposed in the paper "Searching for MobileNetV3" on arXiv, and this project uses pytorch-mobilenet-v3 as its reference implementation.
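One distinctive piece of the MobileNetV3 design is the hard-swish activation, a piecewise-linear stand-in for swish built from ReLU6. The sketch below uses the paper's definitions, h_sigmoid(x) = ReLU6(x + 3) / 6 and h_swish(x) = x * h_sigmoid(x), on plain f32 scalars so it runs without libtorch. It is an illustration of the formulas only; how this crate applies them to tensors is not shown here and may differ.

```rust
/// ReLU6: clamps the input to the range [0, 6].
fn relu6(x: f32) -> f32 {
    x.clamp(0.0, 6.0)
}

/// Hard sigmoid: piecewise-linear approximation of the sigmoid, ReLU6(x + 3) / 6.
fn h_sigmoid(x: f32) -> f32 {
    relu6(x + 3.0) / 6.0
}

/// Hard swish: x * h_sigmoid(x), a cheaper approximation of swish.
fn h_swish(x: f32) -> f32 {
    x * h_sigmoid(x)
}

fn main() {
    // Print both activations at a few sample points on either side of the linear region.
    for x in [-4.0_f32, -1.5, 0.0, 1.5, 4.0] {
        println!(
            "x = {x:5.1}  h_sigmoid = {:.4}  h_swish = {:.4}",
            h_sigmoid(x),
            h_swish(x)
        );
    }
}
```

Below x = -3 both functions are exactly zero, and above x = 3 h_sigmoid saturates at 1 so h_swish becomes the identity; only the band in between needs a multiply.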

Usage

Build

Make sure a stable Rust toolchain and Cargo are installed, then run cargo build to compile the project. To speed up the build, it's recommended to follow the tch-rs README and download pre-built libtorch binaries.

This project provides a demo executable that trains the model on MNIST or CIFAR-10. The model is also available as a library if you would like to integrate it into your own project.

Run demo training

  • CIFAR-10: Download the binary version of the dataset from the CIFAR site, then run:
cargo run -- --dataset-name cifar-10 --dataset-dir /path/to/cifar-10-dir
  • MNIST: Download all gzip files from the MNIST site, unpack them into a single directory, then run:
cargo run -- --dataset-name mnist --dataset-dir /path/to/mnist-dir
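Both commands expect --dataset-dir to point at a directory holding the unpacked dataset files. The sketch below is a hypothetical helper (not part of this project) that checks a directory for the standard file names distributed by the two dataset sites: the four idx files for MNIST, and data_batch_1.bin through data_batch_5.bin plus test_batch.bin for the CIFAR-10 binary version. It assumes the demo reads files under these standard names.

```rust
use std::path::Path;

/// Standard file names for each dataset (assumed to be what the demo loads).
/// Returns None for a dataset name the demo does not accept.
fn expected_files(dataset_name: &str) -> Option<Vec<String>> {
    match dataset_name {
        "mnist" => Some(
            [
                "train-images-idx3-ubyte",
                "train-labels-idx1-ubyte",
                "t10k-images-idx3-ubyte",
                "t10k-labels-idx1-ubyte",
            ]
            .iter()
            .map(|s| s.to_string())
            .collect(),
        ),
        "cifar-10" => {
            // Binary version: five training batches plus one test batch.
            let mut files: Vec<String> = (1..=5).map(|i| format!("data_batch_{i}.bin")).collect();
            files.push("test_batch.bin".to_string());
            Some(files)
        }
        _ => None,
    }
}

/// Lists the expected files that are not present as regular files in `dir`.
fn missing_files(dataset_name: &str, dir: &Path) -> Option<Vec<String>> {
    let files = expected_files(dataset_name)?;
    Some(files.into_iter().filter(|f| !dir.join(f).is_file()).collect())
}

fn main() {
    // Hypothetical usage: check_dataset <dataset-name> <dataset-dir>
    let mut args = std::env::args().skip(1);
    let name = args.next().unwrap_or_else(|| "mnist".to_string());
    let dir = args.next().unwrap_or_else(|| ".".to_string());
    match missing_files(&name, Path::new(&dir)) {
        None => eprintln!("unknown dataset name: {name}"),
        Some(missing) if missing.is_empty() => println!("{dir}: all {name} files found"),
        Some(missing) => println!("{dir}: missing {missing:?}"),
    }
}
```

Running such a check before cargo run makes a wrong --dataset-dir (for example, the parent of the extracted CIFAR-10 folder) show up as a clear list of missing files rather than a load error partway through startup.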

Use as a library

Here is an example of usage. See the source code for further details.

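// tch imports; MobileNetV3 and Mode come from this crate (adjust the path to your dependency name)
use tch::nn::{Adam, ModuleT, OptimizerConfig, VarStore};
use tch::Device;

// model init
let vs = VarStore::new(Device::Cuda(0));  // or Device::cuda_if_available()
let root = vs.root();
let model = MobileNetV3::new(
    &root / "mobilenetv3",
    input_channel,
    n_classes,
    dropout,
    width_mult,  // usually 1.0
    Mode::Large,
)?;
// the optimizer must be mutable because backward_step updates its state
let mut opt = Adam::default().build(&vs, learning_rate)?;

// training step: forward pass in training mode, then one optimizer update
let logits = model.forward_t(&images, true);
let loss = logits.cross_entropy_for_logits(&labels);
opt.backward_step(&loss);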
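The width_mult argument scales the number of channels in every layer. MobileNet implementations commonly round each scaled count to a multiple of 8; the sketch below follows the _make_divisible convention from the TensorFlow and torchvision MobileNet code (round to the nearest multiple, but never lose more than 10% of the channels). Whether this crate rounds exactly this way is an assumption, since its PyTorch reference may use a different rule such as always rounding up. The function name scaled_channels is hypothetical.

```rust
/// Scales `channels` by `width_mult` and rounds the result to a multiple of `divisor`.
/// Hypothetical helper following the `_make_divisible` convention; not this crate's API.
fn scaled_channels(channels: usize, width_mult: f64, divisor: usize) -> usize {
    let v = channels as f64 * width_mult;
    // Round to the nearest multiple of `divisor` (ties round up), never below `divisor`.
    let rounded = ((v + divisor as f64 / 2.0) as usize / divisor) * divisor;
    let mut new_v = rounded.max(divisor);
    // If rounding down removed more than 10% of the channels, bump up one step.
    if (new_v as f64) < 0.9 * v {
        new_v += divisor;
    }
    new_v
}

fn main() {
    // Example channel counts of the kind found in the MobileNetV3-Large layer table.
    let base = [16, 24, 40, 80, 112, 160, 960];
    for width_mult in [0.5, 0.75, 1.0] {
        let scaled: Vec<usize> = base
            .iter()
            .map(|&c| scaled_channels(c, width_mult, 8))
            .collect();
        println!("width_mult = {width_mult}: {scaled:?}");
    }
}
```

For instance, 24 channels at width_mult 0.75 gives 18, which rounds down to 16; that drops more than 10%, so the guard raises it to 24. With width_mult 1.0 every count above is already a multiple of 8 and stays unchanged.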

License

MIT; see the LICENSE file.
