-
Notifications
You must be signed in to change notification settings - Fork 2
/
service_test.go
65 lines (55 loc) · 1.39 KB
/
service_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
package gotensor_test
import (
"io/ioutil"
"testing"
"github.com/helinwang/gotensor"
"github.com/stretchr/testify/require"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// TestFetch feeds the scalar int32(2) into op "a" of the a_plus_1 test graph,
// fetches op "b", and checks that the returned output is int32(3).
func TestFetch(t *testing.T) {
	// The fixture file holds a graph of b = a + 1.
	graph, err := ioutil.ReadFile("./test_data/a_plus_1.pb")
	require.NoError(t, err)

	s, err := gotensor.New(graph)
	require.NoError(t, err)

	tensor, err := tf.NewTensor(int32(2))
	require.NoError(t, err)

	var resp gotensor.Response
	err = s.Run(gotensor.Request{
		Feeds: []gotensor.Feed{
			{
				Edge:   gotensor.Edge{OpName: "a"},
				Tensor: gotensor.Tensor{tensor},
			},
		},
		Fetches: []gotensor.Edge{
			{OpName: "b"},
		},
	}, &resp)
	require.NoError(t, err)
	require.Equal(t, "", resp.Error)

	// One edge was fetched, so exactly one output is expected. Asserting the
	// length before indexing turns a missing output into a regular test
	// failure instead of an index-out-of-range panic.
	require.Len(t, resp.Outputs, 1)
	require.Equal(t, int32(3), resp.Outputs[0].Value())
}
// TestTarget runs op "b" of the a_plus_1 test graph as a target only, with
// int32(2) fed into op "a", and checks that the response carries no outputs.
func TestTarget(t *testing.T) {
	// The fixture file holds a graph of b = a + 1.
	graph, err := ioutil.ReadFile("./test_data/a_plus_1.pb")
	require.NoError(t, err)

	s, err := gotensor.New(graph)
	require.NoError(t, err)

	tensor, err := tf.NewTensor(int32(2))
	require.NoError(t, err)

	var resp gotensor.Response
	err = s.Run(gotensor.Request{
		Feeds: []gotensor.Feed{
			{
				Edge:   gotensor.Edge{OpName: "a"},
				Tensor: gotensor.Tensor{tensor},
			},
		},
		Targets: []string{"b"},
	}, &resp)
	require.NoError(t, err)
	require.Equal(t, "", resp.Error)

	// No fetch is requested, so no output is expected. Empty reports the
	// unexpected contents on failure, not just a length mismatch.
	require.Empty(t, resp.Outputs)
}