Skip to content

Commit

Permalink
Corrected field example (arrayfire#3369)
Browse files Browse the repository at this point in the history
Corrected buffer overflow in vector_field
Join now accepts any buffer as out array, as long as it is large enough.
  • Loading branch information
willyborn committed Mar 28, 2023
1 parent 6736e93 commit 9d72403
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 28 deletions.
41 changes: 23 additions & 18 deletions src/api/c/vector_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ using arrayfire::common::step_round;
using detail::Array;
using detail::copy_vector_field;
using detail::createEmptyArray;
using detail::createValueArray;
using detail::forgeManager;
using detail::reduce;
using detail::transpose;
Expand All @@ -50,25 +51,29 @@ fg_chart setup_vector_field(fg_window window, const vector<af_array>& points,
vector<Array<T>> pnts;
vector<Array<T>> dirs;

for (unsigned i = 0; i < points.size(); ++i) {
pnts.push_back(getArray<T>(points[i]));
dirs.push_back(getArray<T>(directions[i]));
}

// Join for set up vector
dim4 odims(3, points.size());
Array<T> out_pnts = createEmptyArray<T>(odims);
Array<T> out_dirs = createEmptyArray<T>(odims);
detail::join(out_pnts, 1, pnts);
detail::join(out_dirs, 1, dirs);
Array<T> pIn = out_pnts;
Array<T> dIn = out_dirs;

// do transpose if required
if (transpose_) {
pIn = transpose<T>(pIn, false);
dIn = transpose<T>(dIn, false);
// 1D, 2D and 3D input arrays are allowed!!
// Dims of each input array are:
// transpose_==true --> [N,1|2|3,1,1]
// transpose_==false --> [1|2|3,N,1,1]
// Multiple input arrays are allowed to provide X,Y,Z separately
Array<T> out_pnts = getArray<T>(points[0]);
Array<T> out_dirs = getArray<T>(directions[0]);

if (points.size() > 1) {
// Combine X-axis, Y-axis (and Z-axis) into 1 array
dim4 odims(getInfo(points[0]).dims());
odims.dims[transpose_ ? 1 : 0] = points.size();
out_pnts = createEmptyArray<T>(odims);
out_dirs = createEmptyArray<T>(odims);
for (unsigned i = 0; i < points.size(); ++i) {
pnts.push_back(getArray<T>(points[i]));
dirs.push_back(getArray<T>(directions[i]));
}
detail::join(out_pnts, transpose_ ? 1 : 0, pnts);
detail::join(out_dirs, transpose_ ? 1 : 0, dirs);
}
Array<T> pIn = transpose_ ? transpose<T>(out_pnts, false) : out_pnts;
Array<T> dIn = transpose_ ? transpose<T>(out_dirs, false) : out_dirs;

ForgeManager& fgMngr = forgeManager();

Expand Down
8 changes: 8 additions & 0 deletions src/backend/cpu/join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ Array<T> join(const int dim, const Array<T> &first, const Array<T> &second) {
template<typename T>
void join(Array<T> &out, const int dim, const std::vector<Array<T>> &inputs) {
const dim_t n_arrays = inputs.size();
// check if out can accommodate the full join
dim4 jdims(out.dims());
for (auto &iArray : inputs) {
const dim4 &idims(iArray.dims());
for (int i = 0; i < AF_MAX_DIMS; ++i)
ARG_ASSERT(1, jdims.dims[i] >= idims.dims[i]);
jdims.dims[dim] -= idims.dims[dim];
}

std::vector<Array<T> *> input_ptrs(inputs.size());
std::transform(
Expand Down
14 changes: 12 additions & 2 deletions src/backend/cpu/kernel/join.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,17 @@ void join_append(T *out, const T *X, const af::dim4 &offset,
const dim_t xYZW = xZW + oy * xst[1];
const dim_t oYZW = oZW + (oy + offset[1]) * ost[1];

memcpy(out + oYZW + offset[0], X + xYZW, xdims[0] * sizeof(T));
if (ost[0] == 1 && xst[0] == 1) {
memcpy(out + oYZW + offset[0], X + xYZW,
xdims[0] * sizeof(T));
} else {
out += oYZW + offset[0];
X += xYZW;
for (dim_t ox = 0; ox < xdims[0];
ox++, out += ost[0], X += xst[0]) {
*out = *X;
}
}
}
}
}
Expand All @@ -48,7 +58,7 @@ void join_append(T *out, const T *X, const af::dim4 &offset,
template<typename T>
void join(const int dim, Param<T> out, const std::vector<CParam<T>> inputs,
int n_arrays) {
af::dim4 zero(0, 0, 0, 0);
const af::dim4 zero(0, 0, 0, 0);
af::dim4 d = zero;
join_append<T>(out.get(), inputs[0].get(), zero, inputs[0].dims(),
out.strides(), inputs[0].strides());
Expand Down
25 changes: 20 additions & 5 deletions src/backend/cuda/join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ Array<T> join(const int jdim, const Array<T> &first, const Array<T> &second) {

template<typename T>
void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
// out is an externally defined array:
// - with the only restriction that its dims have to be at least as large
// as the dims of the joined inputs.
// - no restrictions on the strides.
// The part of out that is not overwritten by the join remains as is!!
class eval {
public:
vector<Param<T>> outputs;
Expand All @@ -126,6 +131,7 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
};
std::map<dim_t, eval> evals;
const cudaStream_t activeStream{getActiveStream()};
const dim_t *ostrides{out.strides().dims};
const size_t L2CacheSize{getL2CacheSize(getActiveDeviceId())};

// topspeed is achieved when byte size(in+out) ~= L2CacheSize
Expand All @@ -148,19 +154,23 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
// has to be called multiple times

// Group all arrays according to size
dim_t outOffset{0};
dim_t odim{0}, outOffset{0};
const dim_t *odims{out.dims().dims};
for (const Array<T> &iArray : inputs) {
const dim_t *idims{iArray.dims().dims};
for (int i = 0; i < AF_MAX_DIMS; ++i)
ARG_ASSERT(1, odims[i] >= idims[i]);
eval &e{evals[idims[jdim]]};
e.outputs.emplace_back(out.get() + outOffset, idims,
out.strides().dims);
e.outputs.emplace_back(out.get() + outOffset, idims, ostrides);
// Extend life of the returned node by saving the corresponding
// shared_ptr
e.nodePtrs.emplace_back(iArray.getNode());
e.nodes.push_back(e.nodePtrs.back().get());
e.ins.push_back(&iArray);
outOffset += idims[jdim] * out.strides().dims[jdim];
odim += idims[jdim];
outOffset = odim * ostrides[jdim];
}
ARG_ASSERT(1, odims[jdim] >= odim);

for (auto &eval : evals) {
auto &s{eval.second};
Expand All @@ -173,7 +183,12 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
auto outputIt{begin(s.outputs)};
for (const Array<T> *in : s.ins) {
if (in->isReady()) {
if (1LL + jdim >= in->ndims() && in->isLinear()) {
const dim_t *istrides{in->strides().dims};
bool lin = in->isLinear() & (ostrides[0] == 1);
for (int i{1}; i < in->ndims(); ++i) {
lin &= (ostrides[i] == istrides[i]);
}
if (lin) {
CUDA_CHECK(cudaMemcpyAsync(outputIt->ptr, in->get(),
in->elements() * sizeof(T),
cudaMemcpyHostToDevice,
Expand Down
16 changes: 13 additions & 3 deletions src/backend/opencl/join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,12 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
// will be called twice

// Group all arrays according to size
dim_t outOffset{0};
dim_t odim{0}, outOffset{0};
const dim_t *odims{out.dims().dims};
for (const Array<T> &iArray : inputs) {
const dim_t *idims{iArray.dims().dims};
for (int i = 0; i < AF_MAX_DIMS; ++i)
ARG_ASSERT(1, odims[i] >= idims[i]);
eval &e{evals[idims[jdim]]};
const Param output{
out.get(),
Expand All @@ -172,8 +175,10 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
e.nodePtrs.emplace_back(iArray.getNode());
e.nodes.push_back(e.nodePtrs.back().get());
e.ins.push_back(&iArray);
outOffset += idims[jdim] * ostrides[jdim];
odim += idims[jdim];
outOffset = odim * ostrides[jdim];
}
ARG_ASSERT(1, odims[jdim] >= odim);

for (auto &eval : evals) {
auto &s{eval.second};
Expand All @@ -186,7 +191,12 @@ void join(Array<T> &out, const int jdim, const vector<Array<T>> &inputs) {
auto outputIt{begin(s.outputs)};
for (const Array<T> *in : s.ins) {
if (in->isReady()) {
if (1LL + jdim >= in->ndims() && in->isLinear()) {
const dim_t *istrides{in->strides().dims};
bool lin = in->isLinear() & (ostrides[0] == 1);
for (int i{1}; i < in->ndims(); ++i) {
lin &= (ostrides[i] == istrides[i]);
}
if (lin) {
getQueue().enqueueCopyBuffer(
*in->get(), *outputIt->data,
in->getOffset() * sizeof(T),
Expand Down

0 comments on commit 9d72403

Please sign in to comment.