marcosalvalaggio/kiwigrad

Kiwigrad



Despite lacking the ability to fly through the skies like PyTorch and TensorFlow, Kiwigrad is still a formidable bird, teeming with untapped potential waiting to be uncovered. 😉

Kiwigrad? Yes, it is another version of micrograd, created just for fun and experimentation.

Install

To install the current release:

pip install kiwigrad==0.28

Functionalities

Kiwigrad is a modified version of the micrograd and minigrad packages with some additional features. The main additions are:

  • Faster training, thanks to a C implementation of the Value object.
  • Graph tracing, as in the original micrograd package. An example can be found in the ops notebook.
  • Methods for saving and loading the weights of a trained model.
  • Support for RNN(1) feedforward neural networks.
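The speedup mentioned above comes from implementing the Value object in C. For readers unfamiliar with micrograd, the following is a minimal pure-Python sketch of what such a scalar node does: it stores a value and a gradient, records its parents, and back-propagates through the graph in reverse topological order. The class name ToyValue is hypothetical, only + and * are covered, and this is not Kiwigrad's actual implementation.

```python
class ToyValue:
    """Hypothetical pure-Python scalar node (not Kiwigrad's C-backed Value)."""

    def __init__(self, data, parents=()):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents
        self._backward = lambda: None  # pushes self.grad to the parents

    def __add__(self, other):
        other = other if isinstance(other, ToyValue) else ToyValue(other)
        out = ToyValue(self.data + other.data, (self, other))

        def _backward():
            # d(x + y)/dx = 1 and d(x + y)/dy = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, ToyValue) else ToyValue(other)
        out = ToyValue(self.data * other.data, (self, other))

        def _backward():
            # d(x * y)/dx = y and d(x * y)/dy = x
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = _backward
        return out

    def backward(self):
        # Topologically sort the graph so that a node's gradient is fully
        # accumulated before it is propagated to its parents.
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            node._backward()


a = ToyValue(-4.0)
b = ToyValue(2.0)
c = a * b + a  # c = ab + a, so dc/da = b + 1 and dc/db = a
c.backward()
print(c.data, a.grad, b.grad)  # -12.0 3.0 -4.0
```

Because `a` is used twice, its gradient is accumulated with `+=` rather than overwritten, which is the same reason micrograd-style libraries reset gradients to zero between training steps.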

Examples

  • The examples folder contains models trained with the Kiwigrad library.
  • An example declaration of an MLP network using Kiwigrad:
from kiwigrad import MLP, Layer

class PotNet(MLP):
    def __init__(self):
        layers = [
            Layer(nin=2, nout=16, bias=True, activation="relu"),   # 2 inputs -> 16 ReLU units
            Layer(nin=16, nout=16, bias=True, activation="relu"),  # hidden: 16 -> 16 ReLU units
            Layer(nin=16, nout=1, bias=True, activation="linear")  # output: 16 -> 1 linear unit
        ]
        super().__init__(layers=layers)

model = PotNet()
  • Like micrograd, Kiwigrad supports a number of operations:
from kiwigrad import Value, draw_dot

a = Value(-4.0)
b = Value(2.0)
c = a + b
d = a * b + b**3
c += c + Value(1.)
c += Value(1.) + c + (-a)
d += d * Value(2) + (b + a).relu()
d += Value(3.) * d + (b - a).relu()
e = c - d
f = e**2
g = f / Value(2.0)
g += Value(10.0) / f
print(f'{g.data:.4f}') # prints 24.7041, the outcome of this forward pass
g.backward()
print(f'{a.grad:.4f}') # prints 138.8338, i.e. the numerical value of dg/da
print(f'{b.grad:.4f}') # prints 645.5773, i.e. the numerical value of dg/db

draw_dot(g)
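One way to sanity-check the gradients printed above, without relying on Kiwigrad, is to replay the same expression with plain Python floats and estimate dg/da and dg/db by central finite differences, (g(a + h) - g(a - h)) / 2h. The helper names `forward` and `central_diff` are hypothetical, and relu is written as `max(0.0, x)`.

```python
def forward(a, b):
    """Plain-float replay of the Value example above."""

    def relu(x):
        return max(0.0, x)

    c = a + b
    d = a * b + b**3
    c += c + 1.0
    c += 1.0 + c + (-a)
    d += d * 2 + relu(b + a)
    d += 3.0 * d + relu(b - a)
    e = c - d
    f = e**2
    g = f / 2.0
    g += 10.0 / f
    return g


def central_diff(fn, a, b, h=1e-5):
    """Approximate (dg/da, dg/db) with central finite differences."""
    dg_da = (fn(a + h, b) - fn(a - h, b)) / (2 * h)
    dg_db = (fn(a, b + h) - fn(a, b - h)) / (2 * h)
    return dg_da, dg_db


g = forward(-4.0, 2.0)
dg_da, dg_db = central_diff(forward, -4.0, 2.0)
# Compare with 24.7041, 138.8338 and 645.5773 from the Value example.
print(f"{g:.4f} {dg_da:.4f} {dg_db:.4f}")
```

The check is only meaningful away from the relu kinks: at a = -4, b = 2 the arguments b + a = -2 and b - a = 6 are far from zero, so the small perturbation h does not cross a non-differentiable point.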

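For the PotNet declaration in the examples list, it can help to see what its three layers amount to. The plain-Python sketch below assumes each `Layer` with `bias=True` is a dense layer computing act(W x + b) with one bias per output neuron; the `SPECS` list, the helper names, and the uniform(-1, 1) initialization are hypothetical and not taken from Kiwigrad. Under that assumption the network has 2*16 + 16 + 16*16 + 16 + 16*1 + 1 = 337 trainable parameters.

```python
import random

# Hypothetical layer spec mirroring PotNet: (nin, nout, activation)
SPECS = [(2, 16, "relu"), (16, 16, "relu"), (16, 1, "linear")]


def count_params(specs):
    # A dense layer with bias has nin * nout weights plus nout biases.
    return sum(nin * nout + nout for nin, nout, _ in specs)


def init_params(specs, seed=0):
    # Uniform(-1, 1) weights and zero biases: an assumed scheme,
    # not necessarily what Kiwigrad uses.
    rng = random.Random(seed)
    return [
        ([[rng.uniform(-1.0, 1.0) for _ in range(nin)] for _ in range(nout)],
         [0.0] * nout)
        for nin, nout, _ in specs
    ]


def mlp_forward(x, specs, params):
    # Each layer maps its input list to one output per neuron: act(W x + b).
    for (_, _, act), (weights, bias) in zip(specs, params):
        z = [
            sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)
        ]
        x = [max(0.0, v) for v in z] if act == "relu" else z
    return x


params = init_params(SPECS)
print(count_params(SPECS))                            # 337
print(len(mlp_forward([0.5, -1.0], SPECS, params)))   # 1
```

The final "linear" layer leaves its single output unclipped, which is the usual choice for a regression target; the ReLU layers before it clip negative pre-activations to zero.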
Running tests

cd test
pytest .
