/** @file
* @copyright University of Warsaw
* @section LICENSE
* GPLv3+ (see the COPYING file or http://www.gnu.org/licenses/)
*/
#pragma once
#include <libmpdata++/blitz.hpp>
#include <libmpdata++/formulae/arakawa_c.hpp>
#include <libmpdata++/concurr/detail/sharedmem.hpp>
#include <libmpdata++/solvers/detail/monitor.hpp>
#include <libmpdata++/bcond/detail/bcond_common.hpp>
#include <array>
namespace libmpdataxx
{
namespace solvers
{
namespace detail
{
using namespace libmpdataxx::arakawa_c;
// compile-time maximum of two integers (used below for halo-size arithmetic);
// single-return form keeps it a valid C++11 constexpr function
constexpr int max(const int lhs, const int rhs)
{
  return rhs < lhs ? lhs : rhs;
}
// common base for all solvers: owns the time-stepping loop, time-level
// bookkeeping, boundary-condition storage and the adaptive-time-step logic;
// dimension/algorithm-specific work is delegated to pure-virtual helpers
template <typename ct_params_t, int n_tlev_, int minhalo>
class solver_common
{
  public:

  // compile-time constants propagated from ct_params_t / template arguments
  enum { n_eqns = ct_params_t::n_eqns };
  enum { halo = minhalo };
  enum { n_dims = ct_params_t::n_dims };
  enum { n_tlev = n_tlev_ };

  using ct_params_t_ = ct_params_t; // propagate ct_params_t mainly for output purposes
  using real_t = typename ct_params_t::real_t;
  using arr_t = blitz::Array<real_t, n_dims>;
  using bcp_t = std::unique_ptr<bcond::detail::bcond_common<real_t, halo, n_dims>>;
  using ix = typename ct_params_t::ix;
  // with adaptive time stepping solve()/advance() is given a time span (real_t),
  // otherwise a number of steps (int)
  using advance_arg_t = typename std::conditional<ct_params_t::var_dt, real_t, int>::type;

  protected:

  // TODO: output common doesn't know about ct_params_t
  static constexpr bool var_dt = ct_params_t::var_dt;

  // for convenience: true if any third-order-in-divergent-flow MPDATA option is set
  static constexpr bool div3_mpdata = opts::isset(ct_params_t::opts, opts::div_3rd) ||
                                      opts::isset(ct_params_t::opts, opts::div_3rd_dt);

  // boundary-condition objects: [dimension][left/right]
  std::array<std::array<bcp_t, 2>, n_dims> bcs;

  const int rank;

  // di, dj, dk declared here for output purposes
  real_t dt, di, dj, dk, max_abs_div_eps, max_courant;
  // previous time-step size(s); two slots needed for third-order MPDATA
  std::array<real_t, div3_mpdata ? 2 : 1> dt_stash;
  std::array<real_t, n_dims> dijk;

  const idx_t<n_dims> ijk;

  long long int timestep = 0;
  real_t time = 0;
  // per-equation index of the current time level within mem->psi
  std::vector<int> n;

  using mem_t = concurr::detail::sharedmem<real_t, n_dims, n_tlev>;
  mem_t *mem;

  // helper methods invoked by solve()
  virtual void advop(int e) = 0;

  // helper method telling us if equation e is the last one advected assuming increasing order,
  // but taking into account possible delay of advection of some equations
  // and assuming that is_last_eqn is not called for delayed equations before it's called for non-delayed equations
  constexpr bool is_last_eqn(int e)
  {
    return
      (!opts::most_significant(ct_params_t::delayed_step) && e == n_eqns-1) || // no equations with delayed step
      (e == opts::most_significant(ct_params_t::delayed_step)-1);              // last of the delayed equations
  }

  // rotate the time levels of equation e
  virtual void cycle(int e) final
  {
    n[e] = (n[e] + 1) % n_tlev - n_tlev; // -n_tlev so that n+1 does not give out of bounds
    if (is_last_eqn(e)) mem->cycle(rank);
  }

  virtual void xchng(int e) = 0;

  // TODO: implement flagging of valid/invalid halo for optimisations
  virtual void xchng_vctr_alng(arrvec_t<arr_t>&, const bool ad = false, const bool cyclic = false) = 0;

  // store the left/right boundary conditions for dimension d
  void set_bcs(const int &d, bcp_t &bcl, bcp_t &bcr)
  {
    // with distributed memory and cyclic boundary conditions,
    // the leftmost node must send left first, as
    // the rightmost node is waiting
    if (d == 0 && this->mem->distmem.size() > 0 && this->mem->distmem.rank() == 0)
      std::swap(bcl, bcr);
    bcs[d][0] = std::move(bcl);
    bcs[d][1] = std::move(bcr);
  }

  virtual real_t courant_number(const arrvec_t<arr_t>&) = 0;
  virtual real_t max_abs_vctr_div(const arrvec_t<arr_t>&) = 0;

  // returns false if the advector does not change in time
  virtual bool calc_gc() { return false; }

  // used to calculate nondimensionalised first and second time derivatives of the advector
  virtual void calc_ndt_gc() {}

  virtual void scale_gc(const real_t time, const real_t cur_dt, const real_t prev_dt) = 0;

  // one full advection step of equation e
  void solve_loop_body(const int e)
  {
    scale(e, ct_params_t::hint_scale(e));
    xchng(e);
    advop(e);
    if (!is_last_eqn(e))
      mem->barrier();
    cycle(e); // note: assuming ascending order, mem->cycle is done after the last eqn
    scale(e, -ct_params_t::hint_scale(e));
  }

  // thread-aware range extension, messes range guessing in remote_3d bcond
  template <class n_t>
  rng_t extend_range(const rng_t &r, const n_t n) const
  {
    if (mem->size == 1) return r^n;
    return rank == 0             ? rng_t((r - n).first(), r.last()) :
           rank == mem->size - 1 ? rng_t(r.first(), (r + n).last()) :
           r;
  }

  // thread-aware range extension, variadic version
  template <class n_t, class... ns_t>
  rng_t extend_range(const rng_t &r, const n_t n, const ns_t... ns) const
  {
    return extend_range(extend_range(r, n), ns...);
  }

  private:

#if !defined(NDEBUG)
  bool
    hook_ante_step_called = true, // initially true to handle nt=0
    hook_ante_delayed_step_called = true,
    hook_post_step_called = true,
    hook_ante_loop_called = true;
#endif

  protected:

  virtual void hook_ante_step()
  {
    // sanity check if all subclasses call their parents' hooks
#if !defined(NDEBUG)
    hook_ante_step_called = true;
#endif
  }

  virtual void hook_ante_delayed_step()
  {
    // sanity check if all subclasses call their parents' hooks
#if !defined(NDEBUG)
    hook_ante_delayed_step_called = true;
#endif
  }

  virtual void hook_post_step()
  {
    // sanity check if all subclasses call their parents' hooks
#if !defined(NDEBUG)
    hook_post_step_called = true;
#endif
  }

  virtual void hook_ante_loop(const advance_arg_t nt)
  {
#if !defined(NDEBUG)
    hook_ante_loop_called = true;
#endif
    // fill halos in the velocity field
    this->xchng_vctr_alng(mem->GC);

    // adaptive time stepping - for constant-in-time velocity it suffices
    // to change the time step once and do a simple scaling of the advector
    if (ct_params_t::var_dt)
    {
      real_t cfl = courant_number(mem->GC);
      if (cfl > 0)
      {
        auto prev_dt = dt;
        dt *= max_courant / cfl;
        scale_gc(time, dt, prev_dt);
      }
    }
  }

  public:

  // current model time getter
  // (top-level const on the by-value return was dropped - it is meaningless)
  real_t time_() const { return time; }

  struct rt_params_t
  {
    std::array<int, n_dims> grid_size;
    // NOTE(review): blitz::epsilon presumably selects machine epsilon by the
    // argument's type, with the value (44) irrelevant - TODO confirm
    real_t dt = 0, max_abs_div_eps = blitz::epsilon(real_t(44)), max_courant = real_t(0.5);
  };

  // ctor
  solver_common(
    const int &rank,
    mem_t *mem,
    const rt_params_t &p,
    const decltype(ijk) &ijk
  ) :
    // initialisers listed in member-declaration order (the order they actually
    // run in) to avoid -Wreorder and confusion about initialisation sequence
    rank(rank),
    dt(p.dt),
    di(0),
    dj(0),
    dk(0),
    max_abs_div_eps(p.max_abs_div_eps),
    max_courant(p.max_courant),
    dt_stash{},
    ijk(ijk),
    n(n_eqns, 0),
    mem(mem)
  {
    // compile-time sanity checks
    static_assert(n_eqns > 0, "!");

    // run-time sanity checks
    for (int d = 0; d < n_dims; ++d)
      if (p.grid_size[d] < 1)
        throw std::runtime_error("libmpdata++: bogus grid size");
  }

  // dtor
  virtual ~solver_common()
  {
#if defined(USE_MPI)
    // finalize MPI if it was initialized by distmem,
    // otherwise it would break programs that instantiate many solvers;
    // TODO: the MPI standard requires that the same thread that called mpi_init
    //       calls mpi_finalize, we don't ensure it
    if (!libmpdataxx::concurr::detail::mpi_initialized_before && rank == 0)
      MPI_Finalize();
#endif
#if !defined(NDEBUG)
    assert(hook_ante_step_called && "any overriding hook_ante_step() must call parent_t::hook_ante_step()");
    assert(hook_post_step_called && "any overriding hook_post_step() must call parent_t::hook_post_step()");
    assert(hook_ante_loop_called && "any overriding hook_ante_loop() must call parent_t::hook_ante_loop()");
    assert(hook_ante_delayed_step_called && "any overriding hook_ante_delayed_step() must call parent_t::hook_ante_delayed_step()");
#endif
  }

  // advance the solution by nt (a time span with var_dt, a number of steps otherwise)
  virtual void solve(advance_arg_t nt) final
  {
    // multiple calls to solve() are meant to advance the solution by nt
    // TODO: does it really work with var_dt? we do not advance by time exactly ...
    nt += ct_params_t::var_dt ? time : timestep;

    // being generous about out-of-loop barriers
    if (timestep == 0)
    {
      mem->barrier();
#if !defined(NDEBUG)
      hook_ante_loop_called = false;
#endif
      hook_ante_loop(nt);
      mem->barrier();
    }

    // moved here so that if an exception is thrown from hook_ante_loop these do not cause complaints
#if !defined(NDEBUG)
    hook_ante_step_called = false;
    hook_post_step_called = false;
    hook_ante_delayed_step_called = false;
#endif

    // higher-order temporal interpolation for output requires doing a few additional steps
    int additional_steps = ct_params_t::out_intrp_ord;

    while (ct_params_t::var_dt ? (time < nt || additional_steps > 0) : timestep < nt)
    {
      // progress-bar info through thread name (check top -H)
      monitor(float(ct_params_t::var_dt ? time : timestep) / nt); // TODO: does this value make sense with repeated advance() calls?

      // might be used to implement multi-threaded signal handling
      mem->barrier();
      if (mem->panic) break;

      // proper solver stuff
      // for variable-in-time velocity calculate the advector at n+1/2; returns false if
      // the velocity does not change in time
      bool var_gc = calc_gc();

      // for variable-in-time velocity with adaptive time stepping modify the advector
      // to keep the Courant number roughly constant
      if (var_gc && ct_params_t::var_dt)
      {
        real_t cfl = courant_number(mem->GC);
        if (cfl > 0)
        {
          do
          {
            dt *= max_courant / cfl;
            calc_gc();
            cfl = courant_number(mem->GC);
          }
          while (cfl > max_courant);
        }
      }

      // once we set the time step
      // for third-order MPDATA we need to calculate time derivatives of the advector field
      if (var_gc && div3_mpdata) calc_ndt_gc();

      hook_ante_step();

      // advect the non-delayed equations first ...
      for (int e = 0; e < n_eqns; ++e)
      {
        if (opts::isset(ct_params_t::delayed_step, opts::bit(e))) continue;
        solve_loop_body(e);
      }

      hook_ante_delayed_step();

      // ... then the delayed ones
      for (int e = 0; e < n_eqns; ++e)
      {
        if (!opts::isset(ct_params_t::delayed_step, opts::bit(e))) continue;
        solve_loop_body(e);
      }

      timestep++;
      time = ct_params_t::var_dt ? time + dt : timestep * dt;

      // remember previous time-step size(s) for third-order MPDATA
      if (div3_mpdata) dt_stash[1] = dt_stash[0];
      dt_stash[0] = dt;

      hook_post_step();

      if (time >= nt) additional_steps--;
    }

    mem->barrier();
    // note: hook_post_loop was removed as conflicting with multiple-advance()-call logic
  }

  protected:

  // psi[n] getter - just to shorten the code
  // note that e.g. in hook_post_loop it points rather to
  // psi^{n+1} than psi^{n} (hence not using the name psi_n)
  virtual arr_t &state(const int &e) final
  {
    return mem->psi[e][n[e]];
  }

  // index-range helpers for vector (staggered) and scalar fields
  static rng_t rng_vctr(const rng_t &rng) { return rng^h^(halo-1); }
  static rng_t rng_sclr(const rng_t &rng) { return rng^halo; }

  private:

  // divide (exp > 0) or multiply (exp < 0) equation e by 2^|exp|
  // (used to improve accuracy of fields with scale hints)
  void scale(const int &e, const int &exp)
  {
    if (exp == 0) return;
    else if (exp > 0) state(e)(ijk) /= (1 << exp);
    else state(e)(ijk) *= (1 << -exp); // exp < 0 is the only remaining case
  }
};
// primary template - intentionally left empty; the actual implementations
// are provided by partial specialisations selected (per number of
// dimensions) through the trailing SFINAE parameter
template<typename ct_params_t, int n_tlev, int minhalo, class enabler = void>
class solver
{
};
} // namespace detail
} // namespace solvers
} // namespace libmpdataxx