-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1633 from LLNL/feature/burmark1/for_each
Add for_each and for_each_type
- Loading branch information
Showing
4 changed files
with
251 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/*! | ||
****************************************************************************** | ||
* | ||
* \file | ||
* | ||
* \brief Header file providing RAJA for_each templates. | ||
* | ||
****************************************************************************** | ||
*/ | ||
|
||
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// | ||
// Copyright (c) 2016-24, Lawrence Livermore National Security, LLC | ||
// and RAJA project contributors. See the RAJA/LICENSE file for details. | ||
// | ||
// SPDX-License-Identifier: (BSD-3-Clause) | ||
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// | ||
|
||
#ifndef RAJA_util_for_each_HPP | ||
#define RAJA_util_for_each_HPP | ||
|
||
#include "RAJA/config.hpp" | ||
|
||
#include <iterator> | ||
#include <type_traits> | ||
|
||
#include "camp/list.hpp" | ||
|
||
#include "RAJA/pattern/detail/algorithm.hpp" | ||
|
||
#include "RAJA/util/macros.hpp" | ||
#include "RAJA/util/types.hpp" | ||
|
||
namespace RAJA | ||
{ | ||
|
||
namespace detail | ||
{ | ||
|
||
// runtime loop applying func to each element in the range in order | ||
template<typename Iter, typename UnaryFunc> | ||
RAJA_HOST_DEVICE RAJA_INLINE | ||
UnaryFunc for_each(Iter begin, Iter end, UnaryFunc func) | ||
{ | ||
for (; begin != end; ++begin) { | ||
func(*begin); | ||
} | ||
|
||
return func; | ||
} | ||
|
||
// compile time expansion applying func to a each type in the list in order | ||
template <typename UnaryFunc, typename... Ts> | ||
RAJA_HOST_DEVICE RAJA_INLINE | ||
UnaryFunc for_each_type(camp::list<Ts...> const&, UnaryFunc func) | ||
{ | ||
// braced init lists are evaluated in order | ||
int seq_unused_array[] = {0, (func(Ts{}), 0)...}; | ||
RAJA_UNUSED_VAR(seq_unused_array); | ||
|
||
return func; | ||
} | ||
|
||
} // namespace detail | ||
|
||
|
||
/*! | ||
\brief Apply func to all the elements in the given range in order | ||
using a sequential for loop in O(N) operations and O(1) extra memory | ||
see https://en.cppreference.com/w/cpp/algorithm/for_each | ||
*/ | ||
template <typename Container, typename UnaryFunc> | ||
RAJA_HOST_DEVICE RAJA_INLINE | ||
concepts::enable_if_t<UnaryFunc, type_traits::is_range<Container>> | ||
for_each(Container&& c, UnaryFunc func) | ||
{ | ||
using std::begin; | ||
using std::end; | ||
|
||
return detail::for_each(begin(c), end(c), std::move(func)); | ||
} | ||
|
||
/*! | ||
\brief Apply func to each type in the given list in order | ||
using a compile-time expansion in O(N) operations and O(1) extra memory | ||
*/ | ||
template <typename UnaryFunc, typename... Ts> | ||
RAJA_HOST_DEVICE RAJA_INLINE | ||
UnaryFunc for_each_type(camp::list<Ts...> const& c, UnaryFunc func) | ||
{ | ||
return detail::for_each_type(c, std::move(func)); | ||
} | ||
|
||
} // namespace RAJA | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// | ||
// Copyright (c) 2016-24, Lawrence Livermore National Security, LLC | ||
// and RAJA project contributors. See the RAJA/LICENSE file for details. | ||
// | ||
// SPDX-License-Identifier: (BSD-3-Clause) | ||
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// | ||
|
||
/// | ||
/// Source file containing unit tests for for_each | ||
/// | ||
|
||
#include "RAJA_test-base.hpp" | ||
|
||
#include "RAJA_unit-test-types.hpp" | ||
|
||
#include "camp/resource.hpp" | ||
|
||
#include <type_traits> | ||
#include <vector> | ||
#include <set> | ||
|
||
template<typename T> | ||
class ForEachUnitTest : public ::testing::Test {}; | ||
|
||
TYPED_TEST_SUITE(ForEachUnitTest, UnitIndexTypes); | ||
|
||
|
||
TYPED_TEST(ForEachUnitTest, EmptyRange) | ||
{ | ||
std::vector<TypeParam> numbers; | ||
|
||
std::vector<TypeParam> copies; | ||
RAJA::for_each(numbers, [&](TypeParam& number) { | ||
number += 1; | ||
copies.push_back(number); | ||
}); | ||
|
||
ASSERT_EQ(copies.size(), 0); | ||
ASSERT_EQ(numbers.size(), 0); | ||
} | ||
|
||
TYPED_TEST(ForEachUnitTest, VectorRange) | ||
{ | ||
std::vector<TypeParam> numbers; | ||
for (TypeParam i = 0; i < 13; ++i) { | ||
numbers.push_back(i); | ||
} | ||
|
||
std::vector<TypeParam> copies; | ||
RAJA::for_each(numbers, [&](TypeParam& number) { | ||
copies.push_back(number); | ||
number += 1; | ||
}); | ||
|
||
ASSERT_EQ(copies.size(), 13); | ||
for (TypeParam i = 0; i < 13; ++i) { | ||
ASSERT_EQ(numbers[i], copies[i]+1); | ||
} | ||
} | ||
|
||
TYPED_TEST(ForEachUnitTest, RajaSpanRange) | ||
{ | ||
std::vector<TypeParam> numbers; | ||
for (TypeParam i = 0; i < 11; ++i) { | ||
numbers.push_back(i); | ||
} | ||
|
||
std::vector<TypeParam> copies; | ||
RAJA::for_each(RAJA::make_span(numbers.data(), 11), [&](TypeParam& number) { | ||
copies.push_back(number); | ||
number += 1; | ||
}); | ||
|
||
ASSERT_EQ(copies.size(), 11); | ||
for (TypeParam i = 0; i < 11; ++i) { | ||
ASSERT_EQ(numbers[i], copies[i]+1); | ||
} | ||
} | ||
|
||
TYPED_TEST(ForEachUnitTest, SetRange) | ||
{ | ||
std::set<TypeParam> numbers; | ||
for (TypeParam i = 0; i < 6; ++i) { | ||
numbers.insert(i); | ||
} | ||
|
||
std::vector<TypeParam> copies; | ||
RAJA::for_each(numbers, [&](TypeParam const& number) { | ||
copies.push_back(number); | ||
}); | ||
|
||
ASSERT_EQ(copies.size(), 6); | ||
for (TypeParam i = 0; i < 6; ++i) { | ||
ASSERT_EQ(i, copies[i]); | ||
ASSERT_EQ(numbers.count(i), 1); | ||
} | ||
} | ||
|
||
|
||
TYPED_TEST(ForEachUnitTest, EmptyTypeList) | ||
{ | ||
using numbers = camp::list<>; | ||
|
||
std::vector<TypeParam> copies; | ||
RAJA::for_each_type(numbers{}, [&](auto number) { | ||
copies.push_back(number); | ||
}); | ||
|
||
ASSERT_EQ(copies.size(), 0); | ||
} | ||
|
||
|
||
template < typename T, T val > | ||
T get_num(std::integral_constant<T, val>) | ||
{ | ||
return val; | ||
} | ||
|
||
template < typename TypeParam, | ||
std::enable_if_t<std::is_integral<TypeParam>::value>* = nullptr > | ||
void run_int_type_test() | ||
{ | ||
using numbers = camp::list<std::integral_constant<TypeParam, 0>, | ||
std::integral_constant<TypeParam, 1>, | ||
std::integral_constant<TypeParam, 2>, | ||
std::integral_constant<TypeParam, 3>, | ||
std::integral_constant<TypeParam, 4>>; | ||
|
||
std::vector<TypeParam> copies; | ||
RAJA::for_each_type(numbers{}, [&](auto number) { | ||
copies.push_back(get_num(number)); | ||
}); | ||
|
||
ASSERT_EQ(copies.size(), 5); | ||
for (TypeParam i = 0; i < 5; ++i) { | ||
ASSERT_EQ(i, copies[i]); | ||
} | ||
} | ||
/// | ||
template < typename TypeParam, | ||
std::enable_if_t<!std::is_integral<TypeParam>::value>* = nullptr > | ||
void run_int_type_test() | ||
{ | ||
// ignore non-ints | ||
} | ||
|
||
TYPED_TEST(ForEachUnitTest, IntTypeList) | ||
{ | ||
run_int_type_test<TypeParam>(); | ||
} |