Skip to content

Commit

Permalink
gsum works on int64 column, closes #1647 #3464 (#3737)
Browse files Browse the repository at this point in the history
  • Loading branch information
jangorecki authored and mattdowle committed Aug 14, 2019
1 parent e7a00de commit a8e926a
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 28 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@

26. Column binding of zero column `data.table` will now work as expected, [#3334](https://github.com/Rdatatable/data.table/issues/3334). Thanks to @kzenstratus for the report.

27. `integer64` sum-by-group is now properly optimized, [#1647](https://github.com/Rdatatable/data.table/issues/1647), [#3464](https://github.com/Rdatatable/data.table/issues/3464). Thanks to @mlandry22-h2o for the report.

#### NOTES

1. `rbindlist`'s `use.names="check"` now emits its message for automatic column names (`"V[0-9]+"`) too, [#3484](https://github.com/Rdatatable/data.table/pull/3484). See news item 5 of v1.12.2 below.
Expand Down
14 changes: 14 additions & 0 deletions inst/tests/tests.Rraw
Original file line number Diff line number Diff line change
Expand Up @@ -15657,6 +15657,20 @@ test(2076.02, X[on=Y], data.table(a=2:3, b=c(2L,NA_integer_), d=2:1))
test(2076.03, X[on=3], error="When on= is provided but not i=, on= must be a named list or data.table|frame, and a natural join")
test(2076.04, X[on=list(3)], error="When on= is provided but not i=, on= must be a named list or data.table|frame, and a natural join")

# gsum int64 support #1647, #3464
if (test_bit64) {
  # Grouped sum over an integer64 column must agree with the same sum over a
  # plain integer column holding identical values (tests the optimized gsum path).
  DT = data.table(g=1:2, i32=c(2L,-1L,3L,4L), i64=as.integer64(c(2L,-1L,3L,4L)))
  # Compare two by-group results after coercing every column of each to integer,
  # so integer64 and integer sums can be checked for equality directly.
  i64_matches_i32 = function(a, b) isTRUE(all.equal(lapply(a, as.integer), lapply(b, as.integer)))
  test(2077.01, i64_matches_i32(DT[, sum(i32), g], DT[, sum(i64), g]))
  test(2077.02, i64_matches_i32(DT[, sum(i32, na.rm=TRUE), g], DT[, sum(i64, na.rm=TRUE), g]))
  DT[3L, c("i32","i64") := list(NA_integer_, as.integer64(NA))] # some NA group
  test(2077.03, i64_matches_i32(DT[, sum(i32), g], DT[, sum(i64), g]))
  test(2077.04, i64_matches_i32(DT[, sum(i32, na.rm=TRUE), g], DT[, sum(i64, na.rm=TRUE), g]))
  DT[1L, c("i32","i64") := list(NA_integer_, as.integer64(NA))] # all NA group
  test(2077.05, i64_matches_i32(DT[, sum(i32), g], DT[, sum(i64), g]))
  test(2077.06, i64_matches_i32(DT[, sum(i32, na.rm=TRUE), g], DT[, sum(i64, na.rm=TRUE), g]))
}


###################################
# Add new tests above this line #
Expand Down
115 changes: 87 additions & 28 deletions src/gsumm.c
Original file line number Diff line number Diff line change
Expand Up @@ -426,37 +426,96 @@ SEXP gsum(SEXP x, SEXP narmArg)
}
} break;
case REALSXP: {
const double *restrict gx = gather(x, &anyNA);
ans = PROTECT(allocVector(REALSXP, ngrp));
double *restrict ansp = REAL(ans);
memset(ansp, 0, ngrp*sizeof(double));
if (!narm || !anyNA) {
#pragma omp parallel for num_threads(getDTthreads())
for (int h=0; h<highSize; h++) {
double *restrict _ans = ansp + (h<<shift);
for (int b=0; b<nBatch; b++) {
const int pos = counts[ b*highSize + h ];
const int howMany = ((h==highSize-1) ? (b==nBatch-1?lastBatchSize:batchSize) : counts[ b*highSize + h + 1 ]) - pos;
const double *my_gx = gx + b*batchSize + pos;
const uint16_t *my_low = low + b*batchSize + pos;
for (int i=0; i<howMany; i++) {
_ans[my_low[i]] += my_gx[i]; // let NA propagate when !narm
if (!INHERITS(x, char_integer64)) {
const double *restrict gx = gather(x, &anyNA);
ans = PROTECT(allocVector(REALSXP, ngrp));
double *restrict ansp = REAL(ans);
memset(ansp, 0, ngrp*sizeof(double));
if (!narm || !anyNA) {
#pragma omp parallel for num_threads(getDTthreads())
for (int h=0; h<highSize; h++) {
double *restrict _ans = ansp + (h<<shift);
for (int b=0; b<nBatch; b++) {
const int pos = counts[ b*highSize + h ];
const int howMany = ((h==highSize-1) ? (b==nBatch-1?lastBatchSize:batchSize) : counts[ b*highSize + h + 1 ]) - pos;
const double *my_gx = gx + b*batchSize + pos;
const uint16_t *my_low = low + b*batchSize + pos;
for (int i=0; i<howMany; i++) {
_ans[my_low[i]] += my_gx[i]; // let NA propagate when !narm
}
}
}
} else {
// narm==true and anyNA==true
#pragma omp parallel for num_threads(getDTthreads())
for (int h=0; h<highSize; h++) {
double *restrict _ans = ansp + (h<<shift);
for (int b=0; b<nBatch; b++) {
const int pos = counts[ b*highSize + h ];
const int howMany = ((h==highSize-1) ? (b==nBatch-1?lastBatchSize:batchSize) : counts[ b*highSize + h + 1 ]) - pos;
const double *my_gx = gx + b*batchSize + pos;
const uint16_t *my_low = low + b*batchSize + pos;
for (int i=0; i<howMany; i++) {
const double elem = my_gx[i];
if (!ISNAN(elem)) _ans[my_low[i]] += elem;
}
}
}
}
} else {
// narm==true and anyNA==true
#pragma omp parallel for num_threads(getDTthreads())
for (int h=0; h<highSize; h++) {
double *restrict _ans = ansp + (h<<shift);
for (int b=0; b<nBatch; b++) {
const int pos = counts[ b*highSize + h ];
const int howMany = ((h==highSize-1) ? (b==nBatch-1?lastBatchSize:batchSize) : counts[ b*highSize + h + 1 ]) - pos;
const double *my_gx = gx + b*batchSize + pos;
const uint16_t *my_low = low + b*batchSize + pos;
for (int i=0; i<howMany; i++) {
const double elem = my_gx[i];
if (!ISNAN(elem)) _ans[my_low[i]] += elem;
} else { // int64
const int64_t *restrict gx = gather(x, &anyNA);
ans = PROTECT(allocVector(REALSXP, ngrp));
int64_t *restrict ansp = (int64_t *)REAL(ans);
memset(ansp, 0, ngrp*sizeof(int64_t));
if (!anyNA) {
#pragma omp parallel for num_threads(getDTthreads())
for (int h=0; h<highSize; h++) {
int64_t *restrict _ans = ansp + (h<<shift);
for (int b=0; b<nBatch; b++) {
const int pos = counts[ b*highSize + h ];
const int howMany = ((h==highSize-1) ? (b==nBatch-1?lastBatchSize:batchSize) : counts[ b*highSize + h + 1 ]) - pos;
const int64_t *my_gx = gx + b*batchSize + pos;
const uint16_t *my_low = low + b*batchSize + pos;
for (int i=0; i<howMany; i++) {
_ans[my_low[i]] += my_gx[i]; // does not propagate INT64 for !narm
}
}
}
} else { // narm==true/false and anyNA==true
if (!narm) {
#pragma omp parallel for num_threads(getDTthreads())
for (int h=0; h<highSize; h++) {
int64_t *restrict _ans = ansp + (h<<shift);
for (int b=0; b<nBatch; b++) {
const int pos = counts[ b*highSize + h ];
const int howMany = ((h==highSize-1) ? (b==nBatch-1?lastBatchSize:batchSize) : counts[ b*highSize + h + 1 ]) - pos;
const int64_t *my_gx = gx + b*batchSize + pos;
const uint16_t *my_low = low + b*batchSize + pos;
for (int i=0; i<howMany; i++) {
const int64_t elem = my_gx[i];
if (elem!=INT64_MIN) {
_ans[my_low[i]] += elem;
} else {
_ans[my_low[i]] = INT64_MIN;
break;
}
}
}
}
} else {
#pragma omp parallel for num_threads(getDTthreads())
for (int h=0; h<highSize; h++) {
int64_t *restrict _ans = ansp + (h<<shift);
for (int b=0; b<nBatch; b++) {
const int pos = counts[ b*highSize + h ];
const int howMany = ((h==highSize-1) ? (b==nBatch-1?lastBatchSize:batchSize) : counts[ b*highSize + h + 1 ]) - pos;
const int64_t *my_gx = gx + b*batchSize + pos;
const uint16_t *my_low = low + b*batchSize + pos;
for (int i=0; i<howMany; i++) {
const int64_t elem = my_gx[i];
if (elem!=INT64_MIN) _ans[my_low[i]] += elem;
}
}
}
}
}
Expand Down

0 comments on commit a8e926a

Please sign in to comment.