Google's MobileNetV3, implemented with tch-rs, the Rust bindings for libtorch. The MobileNetV3 architecture was proposed in a paper on arXiv, and this project uses pytorch-mobilenet-v3 as its reference implementation.
Make sure stable Rust and Cargo are installed, then run
cargo build
to compile the project. To speed up the build, consider following the tch-rs README to download pre-built libtorch binaries.
This project provides a demo training executable for MNIST and CIFAR-10. A library interface is also available if you would like to integrate the model into your own project.
- CIFAR-10: Download the binary version from the CIFAR site, then run:
cargo run -- --dataset-name cifar-10 --dataset-dir /path/to/cifar-10-dir
- MNIST: Download all gzip archives from the MNIST site, unpack them into one directory, then run:
cargo run -- --dataset-name mnist --dataset-dir /path/to/mnist-dir
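For orientation when preparing the dataset directory: the CIFAR-10 binary version stores each sample as a fixed-length record of one label byte followed by 3072 pixel bytes (the 32x32 red plane, then green, then blue). The sketch below splits a raw buffer into such records using only the standard library. `RECORD_LEN`, `Sample`, and `parse_batch` are hypothetical names chosen for illustration; this is not the loader used by this project.

```rust
/// Bytes per CIFAR-10 binary record: 1 label byte + 3 * 32 * 32 pixel bytes.
const RECORD_LEN: usize = 1 + 3 * 32 * 32;

/// One decoded sample (hypothetical type, for illustration only).
struct Sample<'a> {
    /// Class index in 0..=9.
    label: u8,
    /// 3072 bytes in channel-major order: 1024 red, 1024 green, 1024 blue.
    pixels: &'a [u8],
}

/// Split a raw batch buffer into samples.
/// Returns None if the buffer is not a whole number of records.
fn parse_batch(buf: &[u8]) -> Option<Vec<Sample<'_>>> {
    if buf.len() % RECORD_LEN != 0 {
        return None;
    }
    Some(
        buf.chunks_exact(RECORD_LEN)
            .map(|rec| Sample {
                label: rec[0],
                pixels: &rec[1..],
            })
            .collect(),
    )
}
```

A batch file read with `std::fs::read` can be passed directly as `parse_batch(&bytes)`. A `None` result usually points to a truncated download or to the Python version of the dataset being used by mistake.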
Here is an example of library usage. See the source code for details.
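// imports from tch; MobileNetV3 and Mode come from this crate
use tch::{
    nn::{Adam, ModuleT, OptimizerConfig, VarStore},
    Device,
};

// model init
let mut vs = VarStore::new(Device::Cuda(0));
let root = vs.root();
let model = MobileNetV3::new(
    &root / "mobilenetv3",
    input_channel,
    n_classes,
    dropout,
    width_mult, // usually 1.0
    Mode::Large,
)?;
// declared mut because recent tch releases take &mut self in backward_step
let mut opt = Adam::default().build(&vs, learning_rate)?;

// training
let logits = model.forward_t(&images, true); // true selects training mode
let loss = logits.cross_entropy_for_logits(&labels);
opt.backward_step(&loss);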
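To make the loss step concrete: cross-entropy on logits is a log-softmax over each row followed by the negative log-likelihood of the target class, averaged over the batch. The sketch below reproduces that arithmetic on plain slices with no tch dependency. `cross_entropy_from_logits` is a hypothetical helper written for illustration, not part of tch-rs or this project.

```rust
/// Mean cross-entropy over a batch of raw logits (illustrative helper).
/// `logits` holds one row per sample; `labels` holds the target class index per row.
fn cross_entropy_from_logits(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    assert!(!logits.is_empty(), "batch must not be empty");
    assert_eq!(logits.len(), labels.len(), "one label per row");
    let total: f64 = logits
        .iter()
        .zip(labels)
        .map(|(row, &y)| {
            // Subtract the row maximum so exp() cannot overflow
            let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = row.iter().map(|z| (z - max).exp()).sum::<f64>().ln() + max;
            // Negative log-probability assigned to the target class
            log_sum_exp - row[y]
        })
        .sum();
    total / logits.len() as f64
}
```

With all-zero logits over 10 classes the result is ln(10), roughly 2.30, which is a useful sanity value for the first training steps on MNIST or CIFAR-10: a loss that stays near it suggests the model is not learning.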
MIT. See the LICENSE file.