Skip to content

Commit

Permalink
Implement XTensor support in core.
Browse files Browse the repository at this point in the history
Adds support for `xt::xtensor`, both row and column major. We're missing
`xt::xarray` and views.
  • Loading branch information
1uc committed Apr 8, 2024
1 parent f169f38 commit 59f9986
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 7 deletions.
7 changes: 4 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,16 @@ mark_as_advanced(HIGHFIVE_SANITIZER)

# Check compiler cxx_std requirements
# -----------------------------------
set(HIGHFIVE_CXX_STANDARD_DEFAULT 14)

if(NOT DEFINED CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD ${HIGHFIVE_CXX_STANDARD_DEFAULT})
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
endif()

if(CMAKE_CXX_STANDARD EQUAL 98 OR CMAKE_CXX_STANDARD LESS 14)
message(FATAL_ERROR "HighFive needs to be compiled with at least C++14")
if(CMAKE_CXX_STANDARD EQUAL 98 OR CMAKE_CXX_STANDARD LESS ${HIGHFIVE_CXX_STANDARD_DEFAULT})
message(FATAL_ERROR "HighFive needs to be compiled with at least C++${HIGHFIVE_CXX_STANDARD_DEFAULT}")
endif()

add_compile_definitions(HIGHFIVE_CXX_STD=${CMAKE_CXX_STANDARD})
Expand Down
106 changes: 106 additions & 0 deletions include/highfive/xtensor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#pragma once

#include "bits/H5Inspector_decl.hpp"
#include "H5Exception.hpp"

#include <xtensor/xtensor.hpp>
#include <xtensor/xarray.hpp>
#include <xtensor/xadapt.hpp>

namespace HighFive {
namespace details {

template <class XTensorType, bool IsConstExprRowMajor>
struct xtensor_inspector {
using type = XTensorType;
using value_type = typename type::value_type;
using base_type = typename inspector<value_type>::base_type;
using hdf5_type = base_type;

// TODO prevent non-trivial elements.

static constexpr size_t ndim = type::rank;
static constexpr size_t recursive_ndim = ndim + inspector<value_type>::recursive_ndim;
static constexpr bool is_trivially_copyable = IsConstExprRowMajor &&
std::is_trivially_copyable<value_type>::value &&
inspector<value_type>::is_trivially_copyable;

static std::vector<size_t> getDimensions(const type& val) {
std::array<size_t, ndim> shape = val.shape();
std::vector<size_t> sizes(shape.begin(), shape.end());
return sizes;
}

static void prepare(type& val, const std::vector<size_t>& dims) {
auto shape = std::array<size_t, ndim>{};
std::copy(dims.begin(), dims.begin() + ndim, shape.begin());
val.resize(shape);
}

static hdf5_type* data(type& val) {
if (!is_trivially_copyable) {
throw DataSetException("Invalid used of `inspector<xtensor>::data`.");
}

if (val.size() == 0) {
throw DataSetException("Invalid use of `inspector<xtensor>::data` for empty array.");
}

return inspector<value_type>::data(*val.data());
}

static const hdf5_type* data(const type& val) {
if (!is_trivially_copyable) {
throw DataSetException("Invalid used of `inspector<xtensor>::data`.");
}

if (val.size() == 0) {
throw DataSetException("Invalid use of `inspector<xtensor>::data` for empty array.");
}

return inspector<value_type>::data(*val.data());
}

static void serialize(const type& val, const std::vector<size_t>& dims, hdf5_type* m) {
auto shape = std::array<size_t, ndim>{};
std::copy(dims.begin(), dims.begin() + ndim, shape.begin());
size_t size = compute_total_size(dims);
xt::adapt(m, size, xt::no_ownership(), shape) = val;
}

static void unserialize(const hdf5_type* vec_align,
const std::vector<size_t>& dims,
type& val) {
std::array<size_t, ndim> shape;
std::copy(dims.begin(), dims.begin() + ndim, shape.begin());
size_t size = compute_total_size(dims);
val = xt::adapt(vec_align, size, xt::no_ownership(), shape);
}
};

template <typename T, size_t N>
struct inspector<xt::xtensor<T, N>>: public xtensor_inspector<xt::xtensor<T, N>, true> {
private:
using super = xtensor_inspector<xt::xtensor<T, N>, true>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};

template <typename T, size_t N, xt::layout_type L>
struct inspector<xt::xtensor<T, N, L>>: public xtensor_inspector<xt::xtensor<T, N, L>, false> {
private:
using super = xtensor_inspector<xt::xtensor<T, N, L>, false>;

public:
using type = typename super::type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
using hdf5_type = typename super::hdf5_type;
};

} // namespace details
} // namespace HighFive
87 changes: 83 additions & 4 deletions tests/unit/data_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@
#include <highfive/eigen.hpp>
#endif

#ifdef HIGHFIVE_TEST_XTENSOR
#include <highfive/xtensor.hpp>
#endif


namespace HighFive {
namespace testing {

std::vector<size_t> lstrip(const std::vector<size_t>& indices, size_t n) {
template <class Dims>
std::vector<size_t> lstrip(const Dims& indices, size_t n) {
std::vector<size_t> subindices(indices.size() - n);
for (size_t i = 0; i < subindices.size(); ++i) {
subindices[i] = indices[i + n];
Expand All @@ -30,7 +35,8 @@ std::vector<size_t> lstrip(const std::vector<size_t>& indices, size_t n) {
return subindices;
}

size_t ravel(std::vector<size_t>& indices, const std::vector<size_t> dims) {
template <class Dims>
size_t ravel(std::vector<size_t>& indices, const Dims& dims) {
size_t rank = dims.size();
size_t linear_index = 0;
size_t ld = 1;
Expand All @@ -43,7 +49,8 @@ size_t ravel(std::vector<size_t>& indices, const std::vector<size_t> dims) {
return linear_index;
}

std::vector<size_t> unravel(size_t flat_index, const std::vector<size_t> dims) {
template <class Dims>
std::vector<size_t> unravel(size_t flat_index, const Dims& dims) {
size_t rank = dims.size();
size_t ld = 1;
std::vector<size_t> indices(rank);
Expand All @@ -56,7 +63,8 @@ std::vector<size_t> unravel(size_t flat_index, const std::vector<size_t> dims) {
return indices;
}

static size_t flat_size(const std::vector<size_t>& dims) {
template <class Dims>
static size_t flat_size(const Dims& dims) {
size_t n = 1;
for (auto d: dims) {
n *= d;
Expand Down Expand Up @@ -332,6 +340,7 @@ struct ContainerTraits<boost::numeric::ublas::matrix<T>> {

#endif

// -- Eigen -------------------------------------------------------------------
#if HIGHFIVE_TEST_EIGEN

template <typename EigenType>
Expand Down Expand Up @@ -468,6 +477,76 @@ struct ContainerTraits<Eigen::Map<PlainObjectType, MapOptions>>
};


#endif

// -- XTensor -----------------------------------------------------------------
#if HIGHFIVE_TEST_XTENSOR
template <typename XTensorType>
struct XTensorContainerTraits {
using container_type = XTensorType;
using value_type = typename container_type::value_type;
using base_type = typename ContainerTraits<value_type>::base_type;

static constexpr size_t rank = container_type::rank;
static constexpr bool is_view = ContainerTraits<value_type>::is_view;

static void set(container_type& array,
const std::vector<size_t>& indices,
const base_type& value) {
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
return ContainerTraits<value_type>::set(array[local_indices], lstrip(indices, rank), value);
}

static base_type get(const container_type& array, const std::vector<size_t>& indices) {
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
return ContainerTraits<value_type>::get(array[local_indices], lstrip(indices, rank));
}

static void assign(container_type& dst, const container_type& src) {
dst = src;
}

static container_type allocate(const std::vector<size_t>& dims) {
auto local_dims = std::array<size_t, rank>{};
std::copy(dims.begin(), dims.begin() + rank, local_dims.begin());
auto array = container_type(local_dims);

size_t n_elements = flat_size(local_dims);
for (size_t i = 0; i < n_elements; ++i) {
auto element = ContainerTraits<value_type>::allocate(lstrip(dims, rank));
set(array, unravel(i, local_dims), element);
}

return array;
}

static void deallocate(container_type& array, const std::vector<size_t>& dims) {
auto local_dims = std::vector<size_t>(dims.begin(), dims.begin() + rank);
size_t n_elements = flat_size(local_dims);
for (size_t i_flat = 0; i_flat < n_elements; ++i_flat) {
auto indices = unravel(i_flat, local_dims);
std::vector<size_t> local_indices(indices.begin(), indices.begin() + rank);
ContainerTraits<value_type>::deallocate(array[local_indices], lstrip(dims, rank));
}
}

static void sanitize_dims(std::vector<size_t>& dims, size_t axis) {
ContainerTraits<value_type>::sanitize_dims(dims, axis + rank);
}
};

template <class T, size_t rank, xt::layout_type layout>
struct ContainerTraits<xt::xtensor<T, rank, layout>>
: public XTensorContainerTraits<xt::xtensor<T, rank, layout>> {
private:
using super = XTensorContainerTraits<xt::xtensor<T, rank, layout>>;

public:
using container_type = typename super::container_type;
using value_type = typename super::value_type;
using base_type = typename super::base_type;
};

#endif

template <class T, class C>
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/supported_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ struct EigenMapMatrix {
};
#endif

#ifdef HIGHFIVE_TEST_XTENSOR
template <size_t rank, xt::layout_type layout, class C = type_identity>
struct XTensor {
template <class T>
using type = xt::xtensor<typename C::template type<T>, rank, layout>;
};
#endif

template <class C, class Tuple>
struct ContainerProduct;

Expand Down Expand Up @@ -150,6 +158,10 @@ using supported_array_types = typename ConcatenateTuples<
typename ContainerProduct<STDArray<7, EigenMatrix<3, 5, Eigen::RowMajor>>, scalar_types_eigen>::type,
typename ContainerProduct<STDArray<7, EigenArray<Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>, scalar_types_eigen>::type,
std::tuple<std::array<Eigen::VectorXd, 7>>,
#endif
#ifdef HIGHFIVE_TEST_XTENSOR
typename ContainerProduct<XTensor<3, xt::layout_type::row_major>, scalar_types_eigen>::type,
typename ContainerProduct<XTensor<3, xt::layout_type::column_major>, scalar_types_eigen>::type,
#endif
typename ContainerProduct<STDVector<>, all_scalar_types>::type,
typename ContainerProduct<STDVector<STDVector<>>, some_scalar_types>::type,
Expand Down

0 comments on commit 59f9986

Please sign in to comment.