remote_common.hpp
// common code for ``remote'' MPI boundary conditions for libmpdata++
//
// licensing: GPL v3
// copyright: University of Warsaw
#pragma once
#include <libmpdata++/bcond/detail/bcond_common.hpp>
#if defined(USE_MPI)
# include <boost/serialization/vector.hpp>
# include <boost/mpi/communicator.hpp>
# include <boost/mpi/nonblocking.hpp>
#endif
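// note: when USE_MPI is not defined, the class still compiles, but the
// transfer methods below are stubs that fail an assertion when called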
namespace libmpdataxx
{
namespace bcond
{
namespace detail
{
template <typename real_t, int halo, drctn_e dir, int n_dims>
class remote_common : public detail::bcond_common<real_t, halo, n_dims>
{
using parent_t = detail::bcond_common<real_t, halo, n_dims>;
protected:
using arr_t = blitz::Array<real_t, n_dims>;
using idx_t = blitz::RectDomain<n_dims>;
real_t *buf_send,
*buf_recv;
private:
#if defined(USE_MPI)
boost::mpi::communicator mpicom;
# if defined(NDEBUG)
static const int n_reqs = 2; // data only (TODO: would waiting on the recv request alone be enough?)
static const int n_dbg_send_reqs = 0;
static const int n_dbg_tags = 0;
# else
static const int n_reqs = 4; // data + ranges
static const int n_dbg_send_reqs = 1;
static const int n_dbg_tags = 2;
# endif
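// request slot layout:
//   release builds: reqs[0] = data send, reqs[1] = data recv
//   debug builds:   reqs[0] = data send, reqs[1] = index-range send,
//                   reqs[2] = data recv, reqs[3] = index-range recv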
std::array<boost::mpi::request, n_reqs> reqs;
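// rank of the MPI peer in the given direction; the modulo arithmetic
// wraps rank 0's left neighbour around to the last rank (and vice versa),
// which is what cyclic domains need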
const int peer = dir == left
? (mpicom.rank() - 1 + mpicom.size()) % mpicom.size()
: (mpicom.rank() + 1 ) % mpicom.size();
# if !defined(NDEBUG)
std::pair<int, int> buf_rng;
# endif
#endif
protected:
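// true on ranks that own an edge of the global domain (rank 0 for left,
// the last rank for rght), i.e. where the message wraps around the seam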
const bool is_cyclic =
#if defined(USE_MPI)
(dir == left && mpicom.rank() == 0) ||
(dir == rght && mpicom.rank() == mpicom.size()-1);
#else
false;
#endif
void send_hlpr(
const arr_t &a,
const idx_t &idx_send
)
{
#if defined(USE_MPI)
// distinguishing between left and right messages
// (important e.g. with 2 procs and cyclic bc)
const int
msg_send = dir == left ? left : rght;
// std::cerr << "send_hlpr idx dir " << dir << " : "
// << " (" << idx_send.lbound(0) << ", " << idx_send.ubound(0) << ")"
// << " (" << idx_send.lbound(1) << ", " << idx_send.ubound(1) << ")"
// << " (" << idx_send.lbound(2) << ", " << idx_send.ubound(2) << ")"
// << std::endl;
// arr_send references part of the send buffer that will be used
arr_t arr_send(buf_send, a(idx_send).shape(), blitz::neverDeleteData);
// copying data to be sent
arr_send = a(idx_send);
// launching async data transfer
if(arr_send.size()!=0)
{
// use the pointer+size kind of send instead of serialization of blitz arrays, because
// serialization caused memory leaks, probably because it breaks blitz reference counting
reqs[0] = mpicom.isend(peer, msg_send, buf_send, arr_send.size());
// sending debug information
# if !defined(NDEBUG)
reqs[1] = mpicom.isend(peer, msg_send + n_dbg_tags, std::pair<int,int>(
idx_send[0].first(),
idx_send[0].last()
));
# endif
}
#else
assert(false);
#endif
}
void recv_hlpr(
const arr_t &a,
const idx_t &idx_recv
)
{
#if defined(USE_MPI)
const int
msg_recv = dir == left ? rght : left;
// std::cerr << "recv_hlpr idx dir " << dir << " : "
// << " (" << idx_recv.lbound(0) << ", " << idx_recv.ubound(0) << ")"
// << " (" << idx_recv.lbound(1) << ", " << idx_recv.ubound(1) << ")"
// << " (" << idx_recv.lbound(2) << ", " << idx_recv.ubound(2) << ")"
// << std::endl;
// launching async data transfer
if(a(idx_recv).size()!=0) // TODO: test directly size of idx_recv
{
reqs[1+n_dbg_send_reqs] = mpicom.irecv(peer, msg_recv, buf_recv, a(idx_recv).size());
// sending debug information
# if !defined(NDEBUG)
reqs[3] = mpicom.irecv(peer, msg_recv + n_dbg_tags, buf_rng);
# endif
}
#else
assert(false);
#endif
}
void send(
const arr_t &a,
const idx_t &idx_send
)
{
#if defined(USE_MPI)
send_hlpr(a, idx_send);
// waiting for the transfers to finish
boost::mpi::wait_all(reqs.begin(), reqs.begin() + 1 + n_dbg_send_reqs); // MPI_Waitall is thread-safe?
#else
assert(false);
#endif
}
void recv(
const arr_t &a,
const idx_t &idx_recv
)
{
#if defined(USE_MPI)
//auto arr_recv = recv_hlpr(a, idx_recv);
recv_hlpr(a, idx_recv);
// waiting for the transfers to finish
boost::mpi::wait_all(reqs.begin() + 1 + n_dbg_send_reqs, reqs.end()); // MPI_Waitall is thread-safe?
// a blitz handler for the used part of the receive buffer
arr_t arr_recv(buf_recv, a(idx_recv).shape(), blitz::neverDeleteData); // TODO: shape directly from idx_recv
// checking debug information
// positive modulo (grid_size_0 - 1)
// auto wrap = [this](int n) {return (n % (grid_size_0 - 1) + grid_size_0 - 1) % (grid_size_0 - 1);};
// assert(wrap(buf_rng.first) == wrap(idx_recv[0].first()));
// assert(wrap(buf_rng.second) == wrap(idx_recv[0].last()));
// writing received data to the array
a(idx_recv) = arr_recv;
#else
assert(false);
#endif
}
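// two-sided exchange: both transfers are posted before either is waited on,
// so the send and the receive can progress concurrently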
void xchng(
const arr_t &a,
const idx_t &idx_send,
const idx_t &idx_recv
)
{
#if defined(USE_MPI)
send_hlpr(a, idx_send);
recv_hlpr(a, idx_recv);
// waiting for the transfers to finish
boost::mpi::wait_all(reqs.begin(), reqs.end());
// a blitz handler for the used part of the receive buffer
arr_t arr_recv(buf_recv, a(idx_recv).shape(), blitz::neverDeleteData);
// checking debug information
// positive modulo (grid_size_0 - 1)
// auto wrap = [this](int n) {return (n % (grid_size_0 - 1) + grid_size_0 - 1) % (grid_size_0 - 1);};
// assert(wrap(buf_rng.first) == wrap(idx_recv[0].first()));
// assert(wrap(buf_rng.second) == wrap(idx_recv[0].last()));
// writing received data to the array
a(idx_recv) = arr_recv;
#else
assert(false);
#endif
}
public:
// ctor
remote_common(
const rng_t &i,
const std::array<int, n_dims> &distmem_grid_size,
bool single_threaded = false
) :
parent_t(i, distmem_grid_size, single_threaded)
{
#if defined(USE_MPI)
const int slice_size = n_dims == 1 ? 1 :
                       n_dims == 2 ? distmem_grid_size[1] + 6 :
                       (distmem_grid_size[1] + 6) * (distmem_grid_size[2] + 6); // 3 is the max halo size (?), hence 6 for both sides
//std::cerr << "remote_common ctor, "
// << " distmem_grid_size[0]: " << distmem_grid_size[0]
// << " distmem_grid_size[1]: " << distmem_grid_size[1]
// << " distmem_grid_size[2]: " << distmem_grid_size[2]
// << " slice_size: " << slice_size
// << " halo: " << halo
// << std::endl;
// allocate enough memory in buffers to store largest halos to be sent
buf_send = (real_t *) malloc(halo * slice_size * sizeof(real_t));
buf_recv = (real_t *) malloc(halo * slice_size * sizeof(real_t));
#endif
}
// dtor
~remote_common()
{
#if defined(USE_MPI)
free(buf_send);
free(buf_recv);
#endif
}
};
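// usage sketch (hypothetical -- the names and ranges below are made up for
// illustration and do not come from the derived classes in the library):
//
//   // inside a class derived from remote_common:
//   // idx_send -- indices of owned edge cells to be sent to the peer,
//   // idx_recv -- indices of halo cells to be filled with peer data
//   this->xchng(a, idx_send, idx_recv); // overlapped send + recv + copy-out
//
// send() and recv() expose the two directions separately; each waits for
// its own requests before returning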
} // namespace detail
} // namespace bcond
} // namespace libmpdataxx