Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load model from redisai #539

Open
Duture opened this issue Nov 9, 2022 · 3 comments
Open

Load model from redisai #539

Duture opened this issue Nov 9, 2022 · 3 comments

Comments

@Duture
Copy link

Duture commented Nov 9, 2022

@auxten Hello,
I trained a LeNet model on CIFAR-10 with PyTorch and stored its weights in RedisAI.
I then load the weights from RedisAI and write them into Gorgonia nodes.
However, running the forward function gives me only 10% accuracy.

@auxten
Copy link
Contributor

auxten commented Nov 9, 2022

I have never used RedisAI. Could you share your code so we can figure out what's going wrong?

@Duture
Copy link
Author

Duture commented Nov 9, 2022

/* Thanks for your reply. I am new to Go and deep learning.
// The following is my LeNet code. It is meant to do the same thing as the PyTorch model.
*/

package cifar10

import (
	"log"

	"github.com/RedisAI/redisai-go/redisai"
	G "gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

// dt is the element type passed to every weight node created by NewLeNet.
var dt = tensor.Float32

// redisaiKeys lists the RedisAI keys of the pretrained weights, one tensor per
// key. The order matches the order of the weight nodes built by NewLeNet: two
// convolution kernels followed by three fully-connected matrices.
var redisaiKeys = []string{
	"infty.lenet.conv.0.weight",
	"infty.lenet.conv.3.weight",
	"infty.lenet.fc.0.weight",
	"infty.lenet.fc.2.weight",
	"infty.lenet.fc.4.weight",
}

// LeNet bundles the expression graph of the network with its weight nodes.
// g is the graph handed to NewLeNet; learnables holds the five weight nodes in
// layer order (two convolution kernels, then three fully-connected matrices),
// which is the order fwd consumes them in.
type LeNet struct {
	g          *G.ExprGraph
	learnables G.Nodes
}

// NewLeNet creates the five LeNet weight nodes in g and fills them with the
// pretrained tensors stored in RedisAI under redisaiKeys.
//
// Convolution kernels are copied as-is. Fully-connected weights are transposed
// while loading: PyTorch keeps nn.Linear weights as (out_features,
// in_features), whereas the matrices declared here are (in, out) because the
// forward pass computes x·w. Copying the raw buffer into an (in, out) matrix
// without transposing would scramble the weights.
//
// NOTE(review): this assumes the tensors were exported in PyTorch's native
// layout; confirm against the export script. Only the ".weight" tensors are
// loaded, so if the PyTorch layers were trained with bias terms the outputs
// will still differ from PyTorch's.
//
// Any failure while talking to RedisAI or binding a value terminates the
// process through the log package.
func NewLeNet(g *G.ExprGraph) *LeNet {
	lenet := LeNet{g: g}
	lenet.learnables = G.Nodes{
		G.NewTensor(g, dt, 4, G.WithShape(6, 3, 5, 5), G.WithName("w0"), G.WithInit(G.ValuesOf(float32(1)))),
		G.NewTensor(g, dt, 4, G.WithShape(16, 6, 5, 5), G.WithName("w1"), G.WithInit(G.ValuesOf(float32(1)))),
		G.NewMatrix(g, dt, G.WithShape(400, 120), G.WithName("w2"), G.WithInit(G.ValuesOf(float32(1)))),
		G.NewMatrix(g, dt, G.WithShape(120, 84), G.WithName("w3"), G.WithInit(G.ValuesOf(float32(1)))),
		G.NewMatrix(g, dt, G.WithShape(84, 10), G.WithName("w4"), G.WithInit(G.ValuesOf(float32(1)))),
	}

	// Guard the parallel indexing of keys and nodes below.
	if len(redisaiKeys) != len(lenet.learnables) {
		log.Fatalf("have %d redisai keys for %d weight nodes", len(redisaiKeys), len(lenet.learnables))
	}

	// get pretrained weights from redisai and write to nodes
	client := redisai.Connect("redis://localhost:6379", nil)
	for i, node := range lenet.learnables {
		key := redisaiKeys[i]
		reply, err := client.TensorGet(key, redisai.TensorContentTypeValues)
		if err != nil {
			log.Fatalln(err)
		}
		// The reply is indexed at position 2 for the tensor values.
		if len(reply) < 3 {
			log.Fatalf("unexpected TensorGet reply for %q: %d elements", key, len(reply))
		}
		wt, err := pretrainedTensor(node.Shape(), reply[2])
		if err != nil {
			log.Fatalln(err)
		}
		if err := G.Let(node, wt); err != nil {
			log.Fatalln(err)
		}
	}
	// Closed explicitly rather than deferred: log.Fatal* skips deferred calls.
	if err := client.Close(); err != nil {
		log.Println(err)
	}
	return &lenet
}

// pretrainedTensor wraps data in a dense tensor of the given node shape,
// transposing 2-D (fully-connected) weights from (out, in) to (in, out).
func pretrainedTensor(shape tensor.Shape, data interface{}) (*tensor.Dense, error) {
	if len(shape) != 2 {
		return tensor.New(tensor.WithShape(shape...), tensor.WithBacking(data)), nil
	}
	// Interpret the buffer in its stored (out, in) layout first, then
	// materialize the transpose so the result matches the node's (in, out) shape.
	wt := tensor.New(tensor.WithShape(shape[1], shape[0]), tensor.WithBacking(data))
	if err := wt.T(); err != nil {
		return nil, err
	}
	if err := wt.Transpose(); err != nil {
		return nil, err
	}
	return wt, nil
}

// convFwd appends one convolution stage to the graph: a 5x5 convolution of x
// with the kernel w, followed by ReLU and a 2x2 max-pool with stride 2.
// The pooled node is returned. Errors from any of the three ops are passed to
// G.Must rather than returned.
func (m *LeNet) convFwd(x *G.Node, w *G.Node) *G.Node {
	var (
		kernel = tensor.Shape{5, 5}
		window = tensor.Shape{2, 2}
	)
	conv := G.Must(G.Conv2d(x, w, kernel, []int{0, 0}, []int{1, 1}, []int{1, 1}))
	activated := G.Must(G.Rectify(conv))
	return G.Must(G.MaxPool2D(activated, window, []int{0, 0}, []int{2, 2}))
}

// fcFwd appends one fully-connected stage to the graph: the product x·w
// followed by ReLU. Errors from either op are passed to G.Must rather than
// returned.
func (m *LeNet) fcFwd(x *G.Node, w *G.Node) *G.Node {
	product := G.Must(G.Mul(x, w))
	return G.Must(G.Rectify(product))
}

// fwd builds the forward pass for the input node x and returns the SoftMax
// output node.
//
// The pipeline is: two convolution stages (convFwd), a flatten to
// (batch, channels*height*width), two fully-connected stages with ReLU
// (fcFwd), a final linear layer, and SoftMax over axis 1.
//
// Errors from the reshape, the final product and SoftMax are returned; the
// helper stages convFwd and fcFwd hand their errors to G.Must instead.
func (m *LeNet) fwd(x *G.Node) (out *G.Node, err error) {
	// Conv
	l0 := m.convFwd(x, m.learnables[0])
	l1 := m.convFwd(l0, m.learnables[1])

	// flatten
	shape := l1.Shape()
	b, c, h, w := shape[0], shape[1], shape[2], shape[3]
	r2, err := G.Reshape(l1, tensor.Shape{b, c * h * w})
	if err != nil {
		return nil, err
	}

	// full connect
	l3 := m.fcFwd(r2, m.learnables[2])
	l4 := m.fcFwd(l3, m.learnables[3])

	// The output layer is linear only. Clamping the logits at zero with ReLU
	// before SoftMax would distort the class scores. The weight keys
	// fc.0/fc.2/fc.4 suggest activations sit only between the linear layers.
	// NOTE(review): confirm against the PyTorch model definition.
	logits, err := G.Mul(l4, m.learnables[4])
	if err != nil {
		return nil, err
	}
	return G.SoftMax(logits, 1)
}

@auxten
Copy link
Contributor

auxten commented Nov 9, 2022

First, I don't know whether anyone has tried filling a Gorgonia graph with PyTorch-trained weights, since Gorgonia has never guaranteed that its implementation is identical to that of any other framework. That said, making this work would make Gorgonia more popular.
Second, as a quick check, make sure you also use float32 in your PyTorch code.
Third, maybe you can upload your whole code to GitHub to make it easier to debug.

To clarify, I'm also quite new to Gorgonia. Maybe @chewxy or @owulveryck have better suggestions?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants