-
Notifications
You must be signed in to change notification settings - Fork 2
/
tensor_test.go
50 lines (41 loc) · 986 Bytes
/
tensor_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
package gotensor_test
import (
	"bytes"
	"encoding/gob"
	"fmt"
	"testing"

	"github.com/helinwang/gotensor"
	"github.com/stretchr/testify/require"
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// TestGob checks that a gotensor.Tensor wrapping a tf.Tensor survives a gob
// round trip for each listed scalar element type. Each tensor is written to a
// single stream twice, once by value and once through a pointer, and both
// copies must decode back to the original value.
func TestGob(t *testing.T) {
	testCase := []interface{}{
		float32(2), float64(1),
		int8(3), int16(3), int32(4), int64(5),
		uint8(3), uint16(3),
		complex(100, 8),
		"string",
		// unsupported:
		// uint32, uint64
	}
	for _, v := range testCase {
		// Run each value as a subtest named after its Go type so a failing
		// require call reports which element type broke and does not stop
		// the remaining types from being checked.
		t.Run(fmt.Sprintf("%T", v), func(t *testing.T) {
			tensor, err := tf.NewTensor(v)
			require.NoError(t, err)
			t0 := gotensor.Tensor{tensor}
			require.Equal(t, v, t0.Value())

			// Encode the same tensor twice into one stream: first as a value,
			// then through a pointer, so both forms are exercised.
			var buf bytes.Buffer
			enc := gob.NewEncoder(&buf)
			require.NoError(t, enc.Encode(t0))
			require.NoError(t, enc.Encode(&t0))

			// Decode both entries back, in order, with a single decoder.
			dec := gob.NewDecoder(bytes.NewReader(buf.Bytes()))

			var t1 gotensor.Tensor
			require.NoError(t, dec.Decode(&t1))
			require.Equal(t, v, t1.Value())

			var t2 gotensor.Tensor
			require.NoError(t, dec.Decode(&t2))
			require.Equal(t, v, t2.Value())
		})
	}
}