
offdroid/rprop-tch-rs


Rprop for tch-rs

Resilient Propagation (Rprop) for tch-rs, ported from PyTorch's torch.optim.Rprop.
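As background, here is a plain-Rust sketch of the per-element Rprop rule as torch.optim.Rprop documents it, using its defaults (lr = 0.01, etas = (0.5, 1.2), step_sizes = (1e-6, 50)). It is not this crate's code. `RpropConfig`, `RpropState`, `rprop_step` and `sign` are hypothetical names, and the parameters are plain `f64` slices rather than tch tensors. Only the sign of the gradient is used. Each element keeps its own step size. That step size grows by eta_plus while the gradient sign repeats and shrinks by eta_minus when the sign flips, and on a flip the element skips its update.

```rust
/// Hypothetical hyperparameters; defaults mirror those documented for torch.optim.Rprop.
struct RpropConfig {
    lr: f64,        // initial per-element step size
    eta_minus: f64, // multiplicative decrease on a sign flip
    eta_plus: f64,  // multiplicative increase when the sign repeats
    step_min: f64,
    step_max: f64,
}

impl Default for RpropConfig {
    fn default() -> Self {
        RpropConfig { lr: 0.01, eta_minus: 0.5, eta_plus: 1.2, step_min: 1e-6, step_max: 50.0 }
    }
}

/// Hypothetical per-element state carried between steps.
struct RpropState {
    prev_grad: Vec<f64>,
    step_size: Vec<f64>,
}

impl RpropState {
    fn new(n: usize, cfg: &RpropConfig) -> Self {
        RpropState { prev_grad: vec![0.0; n], step_size: vec![cfg.lr; n] }
    }
}

/// Sign that returns 0.0 for zero (f64::signum would return 1.0 for +0.0).
fn sign(x: f64) -> f64 {
    if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
}

/// One Rprop update applied element-wise, in place.
fn rprop_step(params: &mut [f64], grads: &[f64], state: &mut RpropState, cfg: &RpropConfig) {
    assert_eq!(params.len(), grads.len());
    for i in 0..params.len() {
        let agreement = grads[i] * state.prev_grad[i];
        // Grow the step if the gradient kept its sign, shrink it if the sign flipped.
        let factor = if agreement > 0.0 {
            cfg.eta_plus
        } else if agreement < 0.0 {
            cfg.eta_minus
        } else {
            1.0
        };
        state.step_size[i] = (state.step_size[i] * factor).clamp(cfg.step_min, cfg.step_max);
        // On a sign flip the gradient is treated as zero: no move this step,
        // and the next step will not be counted as a flip again.
        let g = if agreement < 0.0 { 0.0 } else { grads[i] };
        // Only the sign of the gradient is used, never its magnitude.
        params[i] -= sign(g) * state.step_size[i];
        state.prev_grad[i] = g;
    }
}

fn main() {
    let cfg = RpropConfig::default();
    let mut params = vec![1.0, -2.0];
    let mut state = RpropState::new(params.len(), &cfg);
    // Two steps with consistent gradient signs, then one where both signs flip.
    for grads in [[0.5, -4.0], [0.5, -4.0], [-0.5, 4.0]] {
        rprop_step(&mut params, &grads, &mut state, &cfg);
        println!("params = {:?}, step sizes = {:?}", params, state.step_size);
    }
}
```

On the first step both elements move by exactly lr, regardless of gradient magnitude. On the second step their step sizes grow to 0.012. On the third step the signs flip, so the step sizes halve and the parameters stay put.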

⚠️ Currently only tested with simple models!

Licensed under the same terms as PyTorch; see LICENSE.

Usage

Add the following to the [dependencies] section of Cargo.toml:

rprop-tch = { git = "https://github.com/offdroid/rprop-tch-rs.git" }

Usage mirrors tch::nn::Optimizer:

let vs = tch::nn::VarStore::new(tch::Device::Cpu);
// Init model with `vs`
let net: &dyn tch::nn::Module = todo!();
// Build the Rprop optimizer, here with default parameters
let mut opt = rprop_tch::Rprop::build_default(&vs, Some(0.01));
// Training loop
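for _epoch in 1..10 {
    let (x, y): (tch::Tensor, tch::Tensor) = todo!();
    // tch's `mse_loss` takes an explicit reduction mode
    let loss: tch::Tensor = net.forward(&x).mse_loss(&y, tch::Reduction::Mean);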
    // Use it like tch::nn::Optimizer
    opt.zero_grad();
    loss.backward();
    opt.step();
}
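The zero_grad / backward / step cadence of the loop above can be tried without tch. Below is a hypothetical toy in which `ToyRprop` and `backward` are illustrative names, not this crate's API. It applies the same torch.optim.Rprop-style rule to two plain `f64` parameters. It minimizes the quadratic sum of (p_i - t_i)^2, with the gradient written out by hand in place of autograd.

```rust
/// Hypothetical toy optimizer that mirrors the zero_grad / step calls of the README loop.
struct ToyRprop {
    params: Vec<f64>,
    grads: Vec<f64>,
    prev_grad: Vec<f64>,
    step_size: Vec<f64>,
}

impl ToyRprop {
    // Defaults as documented for torch.optim.Rprop: etas (0.5, 1.2), step sizes (1e-6, 50).
    const ETA_MINUS: f64 = 0.5;
    const ETA_PLUS: f64 = 1.2;
    const STEP_MIN: f64 = 1e-6;
    const STEP_MAX: f64 = 50.0;

    fn new(params: Vec<f64>, lr: f64) -> Self {
        let n = params.len();
        ToyRprop { params, grads: vec![0.0; n], prev_grad: vec![0.0; n], step_size: vec![lr; n] }
    }

    /// Clear accumulated gradients before the next backward pass.
    fn zero_grad(&mut self) {
        self.grads.iter_mut().for_each(|g| *g = 0.0);
    }

    /// Apply one Rprop update using the gradients currently stored.
    fn step(&mut self) {
        for i in 0..self.params.len() {
            let agreement = self.grads[i] * self.prev_grad[i];
            // (step-size factor, gradient actually used); a sign flip zeroes the gradient
            let (factor, g) = if agreement > 0.0 {
                (Self::ETA_PLUS, self.grads[i])
            } else if agreement < 0.0 {
                (Self::ETA_MINUS, 0.0)
            } else {
                (1.0, self.grads[i])
            };
            self.step_size[i] = (self.step_size[i] * factor).clamp(Self::STEP_MIN, Self::STEP_MAX);
            self.params[i] -= sign(g) * self.step_size[i];
            self.prev_grad[i] = g;
        }
    }
}

/// Sign that returns 0.0 for zero.
fn sign(x: f64) -> f64 {
    if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
}

/// Hypothetical stand-in for `loss.backward()`: returns the loss and
/// accumulates d/dp of sum_i (p_i - t_i)^2 into the optimizer's gradients.
fn backward(opt: &mut ToyRprop, target: &[f64]) -> f64 {
    let mut loss = 0.0;
    for i in 0..opt.params.len() {
        let diff = opt.params[i] - target[i];
        loss += diff * diff;
        opt.grads[i] += 2.0 * diff;
    }
    loss
}

fn main() {
    let target = [3.0, -1.5];
    let mut opt = ToyRprop::new(vec![0.0, 0.0], 0.01);
    for epoch in 0..300 {
        opt.zero_grad();
        let loss = backward(&mut opt, &target);
        opt.step();
        if epoch % 100 == 0 {
            println!("epoch {epoch}: loss {loss:.6}");
        }
    }
    println!("final params: {:?}", opt.params);
}
```

Because the gradients accumulate in `backward`, skipping `zero_grad` would mix gradients from different iterations. That is why the README loop calls it before `loss.backward()`.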

Example

See the examples directory, or run the basic example with:

cargo run --example basic
