Skip to content

Commit

Permalink
Clear the thread_local vectors and stringstreams in case of an exception
Browse files: browse the repository at this point in the history
  • Loading branch information
willyborn committed Oct 15, 2023
1 parent 3b2919d commit c5713b5
Showing 1 changed file with 63 additions and 59 deletions.
122 changes: 63 additions & 59 deletions src/backend/opencl/jit.cpp
Expand Up @@ -188,71 +188,75 @@ __kernel void )JIT";
thread_local stringstream outOffsetStream;
thread_local stringstream inOffsetsStream;
thread_local stringstream opsStream;
thread_local stringstream kerStream;

int oid{0};
for (size_t i{0}; i < full_nodes.size(); i++) {
const auto& node{full_nodes[i]};
const auto& ids_curr{full_ids[i]};
// Generate input parameters, only needs current id
node->genParams(inParamStream, ids_curr.id, is_linear);
// Generate input offsets, only needs current id
node->genOffsets(inOffsetsStream, ids_curr.id, is_linear);
// Generate the core function body, needs children ids as well
node->genFuncs(opsStream, ids_curr);
size_t output_idx{0};
while (output_idx < output_ids.size() &&
output_ids[output_idx] != ids_curr.id)
++output_idx;
if (output_idx != output_ids.size()) {
outParamStream << "__global "
<< full_nodes[ids_curr.id]->getTypeStr() << " *out"
<< oid << ", int offset" << oid << ",\n";
// Apply output offset
outOffsetStream << "\nout" << oid << " += offset" << oid << ';';
// Generate code to write the output
opsStream << "out" << output_idx << "[idx] = val" << ids_curr.id
<< ";\n";
++oid;
string ret;
try {
int oid{0};
for (size_t i{0}; i < full_nodes.size(); i++) {
const auto& node{full_nodes[i]};
const auto& ids_curr{full_ids[i]};
// Generate input parameters, only needs current id
node->genParams(inParamStream, ids_curr.id, is_linear);
// Generate input offsets, only needs current id
node->genOffsets(inOffsetsStream, ids_curr.id, is_linear);
// Generate the core function body, needs children ids as well
node->genFuncs(opsStream, ids_curr);
size_t output_idx{0};
while (output_idx < output_ids.size() &&
output_ids[output_idx] != ids_curr.id)
++output_idx;
if (output_idx != output_ids.size()) {
outParamStream
<< "__global " << full_nodes[ids_curr.id]->getTypeStr()
<< " *out" << oid << ", int offset" << oid << ",\n";
// Apply output offset
outOffsetStream << "\nout" << oid << " += offset" << oid << ';';
// Generate code to write the output
opsStream << "out" << output_idx << "[idx] = val" << ids_curr.id
<< ";\n";
++oid;
}
}
}

thread_local stringstream kerStream;
kerStream << kernelVoid << funcName << "(\n"
<< inParamStream.str() << outParamStream.str() << dimParams << ")"
<< blockStart;
if (is_linear) {
kerStream << linearInit << inOffsetsStream.str()
<< outOffsetStream.str() << '\n';
if (loop0) kerStream << linearLoop0Start;
kerStream << "\n\n" << opsStream.str();
if (loop0) kerStream << linearLoop0End;
kerStream << linearEnd;
} else {
if (loop0) {
kerStream << stridedLoop0Init << outOffsetStream.str() << '\n'
<< stridedLoop0Start;
kerStream << kernelVoid << funcName << "(\n"
<< inParamStream.str() << outParamStream.str() << dimParams
<< ")" << blockStart;
if (is_linear) {
kerStream << linearInit << inOffsetsStream.str()
<< outOffsetStream.str() << '\n';
if (loop0) kerStream << linearLoop0Start;
kerStream << "\n\n" << opsStream.str();
if (loop0) kerStream << linearLoop0End;
kerStream << linearEnd;
} else {
kerStream << stridedLoopNInit << outOffsetStream.str() << '\n';
if (loop3) kerStream << stridedLoop3Init;
if (loop1) kerStream << stridedLoop1Init << stridedLoop1Start;
if (loop3) kerStream << stridedLoop3Start;
if (loop0) {
kerStream << stridedLoop0Init << outOffsetStream.str() << '\n'
<< stridedLoop0Start;
} else {
kerStream << stridedLoopNInit << outOffsetStream.str() << '\n';
if (loop3) kerStream << stridedLoop3Init;
if (loop1) kerStream << stridedLoop1Init << stridedLoop1Start;
if (loop3) kerStream << stridedLoop3Start;
}
kerStream << "\n\n" << inOffsetsStream.str() << opsStream.str();
if (loop3) kerStream << stridedLoop3End;
if (loop1) kerStream << stridedLoop1End;
if (loop0) kerStream << stridedLoop0End;
kerStream << stridedEnd;
}
kerStream << "\n\n" << inOffsetsStream.str() << opsStream.str();
if (loop3) kerStream << stridedLoop3End;
if (loop1) kerStream << stridedLoop1End;
if (loop0) kerStream << stridedLoop0End;
kerStream << stridedEnd;
kerStream << blockEnd;
ret = kerStream.str();
} catch (...) {
// Prepare for next round
inParamStream.str("");
outParamStream.str("");
inOffsetsStream.str("");
outOffsetStream.str("");
opsStream.str("");
kerStream.str("");
throw;
}
kerStream << blockEnd;
const string ret{kerStream.str()};

// Prepare for next round, limit memory
inParamStream.str("");
outParamStream.str("");
inOffsetsStream.str("");
outOffsetStream.str("");
opsStream.str("");
kerStream.str("");

return ret;
}
Expand Down

0 comments on commit c5713b5

Please sign in to comment.