Skip to content

Commit

Permalink
use STL numerics
Browse files Browse the repository at this point in the history
  • Loading branch information
flaub committed Apr 29, 2024
1 parent c732ffa commit 8670588
Showing 1 changed file with 23 additions and 25 deletions.
48 changes: 23 additions & 25 deletions risc0/circuit/rv32im-sys/cxx/extern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <vector>

Expand Down Expand Up @@ -86,9 +87,14 @@ struct MachineContext {
void sortBytes();
};

struct AccumCell {
FpExt ram;
FpExt bytes;
};

struct AccumContext {
size_t steps;
std::array<std::vector<FpExt>, 2> pools;
std::vector<AccumCell> cells;

void calcPrefixProducts();
};
Expand Down Expand Up @@ -477,13 +483,7 @@ void MachineContext::sortRam() {

{
nvtx3::scoped_range range("update");
size_t pos = 0;
for (size_t i = 0; i < steps; i++) {
size_t idx = pos;
size_t count = ramIndex[i];
pos += count;
ramIndex[i] = idx;
}
std::exclusive_scan(ramIndex.begin(), ramIndex.end(), ramIndex.begin(), 0);
}
}

Expand Down Expand Up @@ -557,13 +557,7 @@ void MachineContext::sortBytes() {

{
nvtx3::scoped_range range("update");
size_t pos = 0;
for (size_t i = 0; i < steps; i++) {
size_t idx = pos;
size_t count = byteIndex[i];
pos += count;
byteIndex[i] = idx;
}
std::exclusive_scan(byteIndex.begin(), byteIndex.end(), byteIndex.begin(), 0);
}
}

Expand All @@ -581,7 +575,7 @@ void extern_plonkWriteAccum_ram(void* ctx,
std::array<Fp, 4> args) {
// printf("plonkWriteAccumRam\n");
AccumContext* actx = static_cast<AccumContext*>(ctx);
actx->pools[0][cycle] = FpExt(args[0], args[1], args[2], args[3]);
actx->cells[cycle].ram = FpExt(args[0], args[1], args[2], args[3]);
}

void extern_plonkWriteAccum_bytes(void* ctx,
Expand All @@ -590,29 +584,34 @@ void extern_plonkWriteAccum_bytes(void* ctx,
std::array<Fp, 4> args) {
// printf("plonkWriteAccumBytes\n");
AccumContext* actx = static_cast<AccumContext*>(ctx);
actx->pools[1][cycle] = FpExt(args[0], args[1], args[2], args[3]);
actx->cells[cycle].bytes = FpExt(args[0], args[1], args[2], args[3]);
}

AccumCell operator*(const AccumCell& lhs, const AccumCell& rhs) {
return AccumCell{lhs.ram * rhs.ram, lhs.bytes * rhs.bytes};
}

void AccumContext::calcPrefixProducts() {
// printf("calcPrefixProducts\n");
for (size_t i = 1; i < steps; i++) {
pools[0][i] *= pools[0][i - 1];
pools[1][i] *= pools[1][i - 1];
}
std::inclusive_scan(cells.begin(),
cells.end(),
cells.begin(),
std::multiplies<AccumCell>{},
AccumCell{FpExt(1), FpExt(1)});
}

std::array<Fp, 4>
extern_plonkReadAccum_ram(void* ctx, size_t cycle, const char* extra, std::array<Fp, 0> args) {
// printf("plonkReadAccumRam\n");
AccumContext* actx = static_cast<AccumContext*>(ctx);
const FpExt& item = actx->pools[0][cycle];
const FpExt& item = actx->cells[cycle].ram;
return {item.elems[0], item.elems[1], item.elems[2], item.elems[3]};
}

std::array<Fp, 4>
extern_plonkReadAccum_bytes(void* ctx, size_t cycle, const char* extra, std::array<Fp, 0> args) {
AccumContext* actx = static_cast<AccumContext*>(ctx);
const FpExt& item = actx->pools[1][cycle];
const FpExt& item = actx->cells[cycle].bytes;
return {item.elems[0], item.elems[1], item.elems[2], item.elems[3]};
}

Expand Down Expand Up @@ -652,8 +651,7 @@ void risc0_circuit_rv32im_sort_bytes(risc0_error* err, MachineContext* ctx) {
AccumContext* risc0_circuit_rv32im_accum_context_alloc(size_t steps) {
AccumContext* ctx = new AccumContext;
ctx->steps = steps;
ctx->pools[0].resize(steps, FpExt(1));
ctx->pools[1].resize(steps, FpExt(1));
ctx->cells.resize(steps, AccumCell{FpExt(1), FpExt(1)});
return ctx;
}

Expand Down

0 comments on commit 8670588

Please sign in to comment.