Merge pull request #129 from marty1885/apichange
more indexing methods
marty1885 committed Jan 7, 2020
2 parents ffb5cf8 + 958af35 commit ed1cd88
Showing 3 changed files with 32 additions and 24 deletions.
42 changes: 23 additions & 19 deletions Etaler/Core/Tensor.hpp
@@ -211,15 +211,15 @@ struct ETALER_EXPORT Tensor
Tensor log() const { return backend()->log(pimpl()); }
Tensor logical_not() const { return backend()->logical_not(pimpl()); }

- Tensor add(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->add(a(), b()); }
- Tensor subtract(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->subtract(a(), b()); }
- Tensor mul(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->mul(a(), b()); }
- Tensor div(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->div(a(), b()); }
- Tensor equal(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->equal(a(), b()); }
- Tensor greater(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->greater(a(), b()); }
- Tensor lesser(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->lesser(a(), b()); }
- Tensor logical_and(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_and(a(), b()); }
- Tensor logical_or(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_or(a(), b()); }
+ Tensor add(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->add(a.pimpl(), b.pimpl()); }
+ Tensor subtract(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->subtract(a.pimpl(), b.pimpl()); }
+ Tensor mul(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->mul(a.pimpl(), b.pimpl()); }
+ Tensor div(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->div(a.pimpl(), b.pimpl()); }
+ Tensor equal(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->equal(a.pimpl(), b.pimpl()); }
+ Tensor greater(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->greater(a.pimpl(), b.pimpl()); }
+ Tensor lesser(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->lesser(a.pimpl(), b.pimpl()); }
+ Tensor logical_and(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_and(a.pimpl(), b.pimpl()); }
+ Tensor logical_or(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_or(a.pimpl(), b.pimpl()); }

inline bool any() const { return cast(DType::Bool).sum(std::nullopt, DType::Bool).item<uint8_t>(); }
inline bool all() const { return cast(DType::Bool).sum(std::nullopt).item<int32_t>() == int32_t(size()); }
@@ -252,14 +252,14 @@ struct ETALER_EXPORT Tensor

//Subscription operator
Tensor operator [] (const IndexList& r) { return view(r); }
+ template <typename ... Args>
+ Tensor operator () (Args ... args) { return view({args ...}); }

Tensor sum(std::optional<intmax_t> dim=std::nullopt, DType dtype=DType::Unknown) const;
Tensor abs() const { return backend()->abs(pimpl()); }
bool isSame (const Tensor& other) const;

//Utils
- TensorImpl* operator () () {return pimpl();}
- const TensorImpl* operator () () const {return pimpl();}

using iterator = TensorIterator<Tensor>;
using const_iterator = TensorIterator<const Tensor>;
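
This hunk is the heart of the API change: the `operator ()` overloads that exposed the raw `TensorImpl*` are removed (which is why every call site above switches from `a()` to `a.pimpl()`), and a variadic `operator ()` that forwards to `view()` takes their place. A minimal sketch of what this means for user code, with an invented shape; only `ones`, `view`, and `pimpl()` from the diff itself are assumed:

```cpp
// Sketch only, not code from this commit. Assumes et::ones and the
// Tensor members shown in the diff above; the shape is invented.
#include <Etaler/Core/Tensor.hpp>
using namespace et;

void indexing_sketch()
{
	Tensor t = ones({4, 4});

	// Before: t() returned the raw TensorImpl*, so backend calls read
	// backend()->add(a(), b()). That accessor is gone; use pimpl():
	const TensorImpl* impl = t.pimpl();
	(void)impl;

	// After: operator() is a variadic subscript forwarding to view(),
	// so t(2) is shorthand for t.view({2}).
	Tensor row = t(2);
	(void)row;
}
```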
@@ -332,18 +332,18 @@ inline Tensor cellActivity(const Tensor& x, const Tensor& connections, const Ten
return x;
return x.cast(DType::Bool);
}();
- return x.backend()->cellActivity(input(), connections(), permeances(), connected_permeance, active_threshold, has_unconnected_synapse);
+ return x.backend()->cellActivity(input.pimpl(), connections.pimpl(), permeances.pimpl(), connected_permeance, active_threshold, has_unconnected_synapse);
}

inline void learnCorrilation(const Tensor& x, const Tensor& learn, const Tensor& connection
, Tensor& permeances, float perm_inc, float perm_dec, bool has_unconnected_synapse=true)
{
- x.backend()->learnCorrilation(x(), learn(), connection(), permeances(), perm_inc, perm_dec, has_unconnected_synapse);
+ x.backend()->learnCorrilation(x.pimpl(), learn.pimpl(), connection.pimpl(), permeances.pimpl(), perm_inc, perm_dec, has_unconnected_synapse);
}

inline Tensor globalInhibition(const Tensor& x, float fraction)
{
- return x.backend()->globalInhibition(x(), fraction);
+ return x.backend()->globalInhibition(x.pimpl(), fraction);
}

Tensor inline cast(const Tensor& x, DType dtype)
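
These free functions are thin wrappers that unwrap each `Tensor` and delegate to the backend, which is why they all change in this commit. A rough sketch of the call shape for `globalInhibition`, with invented scores and density; `ones(shape, dtype)` and `DType::Int32` are assumptions extrapolated from the `ones_like`/`zeros_like` helpers below:

```cpp
// Rough sketch; the overlap scores and the 10% density are invented.
#include <Etaler/Core/Tensor.hpp>
using namespace et;

void inhibition_sketch()
{
	// Pretend per-column overlap scores, normally produced by cellActivity().
	Tensor overlaps = ones({8}, DType::Int32);

	// Keep roughly the top 10% most active columns as a boolean SDR.
	Tensor active = globalInhibition(overlaps, 0.1f);
	(void)active;
}
```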
@@ -358,27 +358,27 @@ inline Tensor copy(const Tensor& x)

inline void sortSynapse(Tensor& connection, Tensor& permeances)
{
- connection.backend()->sortSynapse(connection(), permeances());
+ connection.backend()->sortSynapse(connection.pimpl(), permeances.pimpl());
}

inline Tensor burst(const Tensor& x, const Tensor& s)
{
- return x.backend()->burst(x(), s());
+ return x.backend()->burst(x.pimpl(), s.pimpl());
}

inline Tensor reverseBurst(const Tensor& x)
{
- return x.backend()->reverseBurst(x());
+ return x.backend()->reverseBurst(x.pimpl());
}

inline void growSynapses(const Tensor& x, const Tensor& y, Tensor& connections, Tensor& permeances, float init_perm)
{
- x.backend()->growSynapses(x(), y(), connections(), permeances(), init_perm);
+ x.backend()->growSynapses(x.pimpl(), y.pimpl(), connections.pimpl(), permeances.pimpl(), init_perm);
}

inline void decaySynapses(Tensor& connections, Tensor& permeances, float threshold)
{
- connections.backend()->decaySynapses(connections(), permeances(), threshold);
+ connections.backend()->decaySynapses(connections.pimpl(), permeances.pimpl(), threshold);
}

inline void assign(Tensor& x, const Tensor& y)
@@ -420,6 +420,10 @@ inline Tensor logical_or(const Tensor& x1, const Tensor& x2) { return x1.logical
inline bool all(const Tensor& t) { return t.all(); }
inline bool any(const Tensor& t) { return t.any(); }

+ template <typename ... Args>
+ inline Tensor view(const Tensor& t, Args... args) { return t.view({args...}); }
+ inline Tensor dynamic_view(const Tensor& t, const IndexList& indices) { return t.view(indices); }
+
inline Tensor zeros_like(const Tensor& x) { return zeros(x.shape(), x.dtype(), x.backend()); }
inline Tensor ones_like(const Tensor& x) { return ones(x.shape(), x.dtype(), x.backend()); }
}
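
The new free functions at the end of the header give xtensor-style spellings for the same indexing: `view(t, ...)` for indices written inline, `dynamic_view(t, lst)` for an `IndexList` built at runtime. A short sketch mirroring the test added in `tests/common_tests.cpp` below; the shape is invented:

```cpp
// Sketch mirroring the new test below; only view, dynamic_view,
// IndexList, and ones from this diff are assumed.
#include <Etaler/Core/Tensor.hpp>
using namespace et;

void view_sketch()
{
	Tensor t = ones({4, 4});

	Tensor a = view(t, 2);            // same as t.view({2})

	IndexList lst;                    // indices decided at runtime
	lst.push_back(3);
	Tensor b = dynamic_view(t, lst);  // same as t.view({3})
	(void)a; (void)b;
}
```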
2 changes: 1 addition & 1 deletion docs/source/PythonBindings.md
@@ -1,7 +1,7 @@
# Python bindings

## PyEtaler
- [PyEtaler](https://guthub.com/etaler/pyetaler) is the offical binding for Etaler. We try to keep the Python API as close to the C++ one as possible. So you can use the C++ document as the Python document. With that said, some functions are changed in the binding to make it more Pythonic.
+ [PyEtaler](https://github.com/etaler/pyetaler) is the offical binding for Etaler. We try to keep the Python API as close to the C++ one as possible. So you can use the C++ document as the Python document. With that said, some functions are changed in the binding to make it more Pythonic.

```python
>>> from etaler import et
12 changes: 8 additions & 4 deletions tests/common_tests.cpp
@@ -9,7 +9,6 @@
#include <Etaler/Algorithms/SDRClassifer.hpp>

#include <numeric>
- #include <execution>

using namespace et;

@@ -320,6 +319,14 @@ TEST_CASE("Testing Tensor", "[Tensor]")
CHECK((ones({4,4}) == t).any() == true);
CHECK((ones({4,4}) == t).all() == false);
}
+
+ SECTION("xtensor style views") {
+ CHECK(view(t, 2).isSame(t.view({2})));
+
+ IndexList lst;
+ lst.push_back(3);
+ CHECK(dynamic_view(t, lst).isSame(t.view({3})));
+ }
}

SECTION("item") {
@@ -994,11 +1001,8 @@ TEST_CASE("Complex Tensor operations")
// Test summing along the first dimension. Making sure iterator and sum() works
// Tho you should always use the sum() function instead of accumulate or reduce
Tensor t = std::accumulate(a.begin(), a.end(), zeros({a.shape()[1]}));
- Tensor q = std::reduce(std::execution::par, a.begin(), a.end(), zeros({a.shape()[1]}));
Tensor a_sum = a.sum(0);
CHECK(t.isSame(a_sum));
- CHECK(q.isSame(a_sum));
- CHECK(t.isSame(q)); // Should be communicative
}

SECTION("generate") {
