Skip to content

Commit

Permalink
Merge pull request #37140 from peterfpeterson/ewm3412_pulse_indexer_i…
Browse files Browse the repository at this point in the history
…terator_ornlnext

Add forward iterator to PulseIndexer - ornl-next
  • Loading branch information
peterfpeterson committed Apr 8, 2024
2 parents 98ab6ff + c7ba150 commit ecf3f7a
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 36 deletions.
42 changes: 42 additions & 0 deletions Framework/DataHandling/inc/MantidDataHandling/PulseIndexer.h
Expand Up @@ -32,6 +32,43 @@ namespace DataHandling {
*/
class MANTID_DATAHANDLING_DLL PulseIndexer {
public:
// ----------------------------------------- input iterator allows for read-only access
struct IteratorValue {
std::size_t pulseIndex;
std::size_t eventIndexStart;
std::size_t eventIndexStop;
};

struct MANTID_DATAHANDLING_DLL Iterator {
using iterator_category = std::input_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = IteratorValue;

Iterator(const PulseIndexer *indexer, const size_t pulseIndex)
: m_indexer(indexer), m_lastPulseIndex(m_indexer->getLastPulseIndex()) {
m_value.pulseIndex = pulseIndex;
calculateEventRange();
}

const IteratorValue &operator*() const;

// prefix increment ++iter
Iterator &operator++();
// postfix increment iter++ is not needed

bool operator==(const PulseIndexer::Iterator &other) const;
bool operator!=(const PulseIndexer::Iterator &other) const;

private:
const PulseIndexer *m_indexer;
const size_t m_lastPulseIndex;

bool calculateEventRange();

IteratorValue m_value;
};

// ----------------------------------------- pulse indexer class start
PulseIndexer(std::shared_ptr<std::vector<uint64_t>> event_index, const std::size_t firstEventIndex,
const std::size_t numEvents, const std::string &entry_name, const std::vector<size_t> &pulse_roi);

Expand All @@ -53,6 +90,11 @@ class MANTID_DATAHANDLING_DLL PulseIndexer {
*/
size_t getStopEventIndex(const size_t pulseIndex) const;

const Iterator cbegin() const;
const Iterator cend() const;
Iterator begin() const;
Iterator end() const;

private:
PulseIndexer(); // do not allow empty constructor

Expand Down
15 changes: 3 additions & 12 deletions Framework/DataHandling/src/LoadErrorEventsNexus.cpp
Expand Up @@ -115,22 +115,13 @@ void LoadErrorEventsNexus::exec() {

const PulseIndexer pulseIndexer(event_index, event_index->at(0), numEvents, "bank_error_events",
std::vector<size_t>());
const auto firstPulseIndex = pulseIndexer.getFirstPulseIndex();
const auto lastPulseIndex = pulseIndexer.getLastPulseIndex();

for (std::size_t pulseIndex = firstPulseIndex; pulseIndex < lastPulseIndex; pulseIndex++) {
// determine range of events for the pulse
const auto eventIndexRange = pulseIndexer.getEventIndexRange(pulseIndex);
if (eventIndexRange.first > numEvents)
break;
else if (eventIndexRange.first == eventIndexRange.second)
continue;

for (const auto &pulseIter : pulseIndexer) {
// Save the pulse time at this index for creating those events
const auto &pulsetime = bankPulseTimes->pulseTime(pulseIndex);
const auto &pulsetime = bankPulseTimes->pulseTime(pulseIter.pulseIndex);

// loop through events associated with a single pulse
for (std::size_t eventIndex = eventIndexRange.first; eventIndex < eventIndexRange.second; ++eventIndex) {
for (std::size_t eventIndex = pulseIter.eventIndexStart; eventIndex < pulseIter.eventIndexStop; ++eventIndex) {
const auto tof = static_cast<double>(event_times[eventIndex]);
ev.addEventQuickly(Mantid::Types::Event::TofEvent(tof, pulsetime));
min_tof = std::min(min_tof, tof);
Expand Down
21 changes: 6 additions & 15 deletions Framework/DataHandling/src/ProcessBankData.cpp
Expand Up @@ -118,25 +118,16 @@ void ProcessBankData::run() {
}

const PulseIndexer pulseIndexer(event_index, startAt, numEvents, entry_name, pulseROI);
const auto firstPulseIndex = pulseIndexer.getFirstPulseIndex(); // for whole file reading is 0
const auto lastPulseIndex = pulseIndexer.getLastPulseIndex(); // for whole file reading is event_index->size()

// loop over all pulses
for (std::size_t pulseIndex = firstPulseIndex; pulseIndex < lastPulseIndex; pulseIndex++) {
// determine range of events for the pulse
const auto eventIndexRange = pulseIndexer.getEventIndexRange(pulseIndex);
if (eventIndexRange.first >= numEvents) // already process all the events that were requested
break;
else if (eventIndexRange.first == eventIndexRange.second) // empty range
continue;

for (const auto &pulseIter : pulseIndexer) {
// Save the pulse time at this index for creating those events
const auto &pulsetime = thisBankPulseTimes->pulseTime(pulseIndex);
const int logPeriodNumber = thisBankPulseTimes->periodNumber(pulseIndex);
const int periodIndex = logPeriodNumber - 1;
const auto &pulsetime = thisBankPulseTimes->pulseTime(pulseIter.pulseIndex);
const int logPeriodNumber = thisBankPulseTimes->periodNumber(pulseIter.pulseIndex);
const auto periodIndex = static_cast<size_t>(logPeriodNumber - 1);

// loop through events associated with a single pulse
for (std::size_t eventIndex = eventIndexRange.first; eventIndex < eventIndexRange.second; ++eventIndex) {
for (std::size_t eventIndex = pulseIter.eventIndexStart; eventIndex < pulseIter.eventIndexStop; ++eventIndex) {
// We cached a pointer to the vector<tofEvent> -> so retrieve it and add
// the event
const detid_t &detId = static_cast<detid_t>((*event_detid)[eventIndex]);
Expand Down Expand Up @@ -189,7 +180,7 @@ void ProcessBankData::run() {
} // valid detector IDs
} // for events in pulse
// check if cancelled after each 100s of pulses (assumes 60Hz)
if ((pulseIndex % 6000 == 0) && alg->getCancel())
if ((pulseIter.pulseIndex % 6000 == 0) && alg->getCancel())
return;
} // for pulses

Expand Down
84 changes: 83 additions & 1 deletion Framework/DataHandling/src/PulseIndexer.cpp
Expand Up @@ -29,11 +29,38 @@ PulseIndexer::PulseIndexer(std::shared_ptr<std::vector<uint64_t>> event_index, c
if (pulse_roi.size() % 2 != 0)
throw std::runtime_error("Invalid size for pulsetime roi, must be even or empty");

// new roi is the intersection of these two
auto roi_combined = Mantid::Kernel::ROI::calculate_intersection(m_roi, pulse_roi);
m_roi.clear();
m_roi.assign(roi_combined.cbegin(), roi_combined.cend());
m_roi_complex = bool(m_roi.size() > 2);
}

// determine if should trim the front end to remove empty pulses
auto firstPulseIndex = m_roi.front();
auto eventRange = this->getEventIndexRange(firstPulseIndex);
while (eventRange.first == eventRange.second) {
++firstPulseIndex;
eventRange = this->getEventIndexRange(firstPulseIndex);
}

// determine if should trim the back end to remove empty pulses
auto lastPulseIndex = m_roi.back();
eventRange = this->getEventIndexRange(lastPulseIndex - 1);
while (eventRange.first == eventRange.second) {
--lastPulseIndex;
eventRange = this->getEventIndexRange(lastPulseIndex - 1);
}

// update the value if it has changed
if ((firstPulseIndex != m_roi.front()) || (lastPulseIndex != m_roi.back())) {
auto roi_combined = Mantid::Kernel::ROI::calculate_intersection(m_roi, {firstPulseIndex, lastPulseIndex});
m_roi.clear();
m_roi.assign(roi_combined.cbegin(), roi_combined.cend());
}

// after the updates, recalculate if the roi is more than a single region
m_roi_complex = bool(m_roi.size() > 2);
}

/**
Expand Down Expand Up @@ -98,7 +125,7 @@ size_t PulseIndexer::determineLastPulseIndex() const {
break;
}

return static_cast<size_t>(m_event_index->size() - std::distance(m_event_index->crbegin(), event_index_iter));
return m_event_index->size() - static_cast<size_t>(std::distance(m_event_index->crbegin(), event_index_iter));
}

size_t PulseIndexer::getFirstPulseIndex() const { return m_roi.front(); }
Expand Down Expand Up @@ -188,4 +215,59 @@ size_t PulseIndexer::getStopEventIndex(const size_t pulseIndex) const {
return m_numEvents;
}

// ----------------------------------------- range for iteration
const PulseIndexer::Iterator PulseIndexer::cbegin() const {
return PulseIndexer::Iterator(this, this->getFirstPulseIndex());
}

const PulseIndexer::Iterator PulseIndexer::cend() const {
return PulseIndexer::Iterator(this, this->getLastPulseIndex());
}

PulseIndexer::Iterator PulseIndexer::begin() const { return PulseIndexer::Iterator(this, this->getFirstPulseIndex()); }

PulseIndexer::Iterator PulseIndexer::end() const { return PulseIndexer::Iterator(this, this->getLastPulseIndex()); }

// ----------------------------------------- input iterator implementation

/// returns true if the range is empty
bool PulseIndexer::Iterator::calculateEventRange() {
const auto eventRange = m_indexer->getEventIndexRange(m_value.pulseIndex);
m_value.eventIndexStart = eventRange.first;
m_value.eventIndexStop = eventRange.second;

return m_value.eventIndexStart == m_value.eventIndexStop;
}

const PulseIndexer::IteratorValue &PulseIndexer::Iterator::operator*() const { return m_value; }

PulseIndexer::Iterator &PulseIndexer::Iterator::operator++() {
++m_value.pulseIndex;
// cache the final pulse index to use
const auto lastPulseIndex = m_indexer->m_roi.back();

// advance to the next included pulse
while ((m_value.pulseIndex < lastPulseIndex) && (!m_indexer->includedPulse(m_value.pulseIndex)))
++m_value.pulseIndex;

// return early if this has advanced to the end
if (m_value.pulseIndex >= lastPulseIndex)
return *this;

while (this->calculateEventRange() && (m_value.pulseIndex < lastPulseIndex)) {
++m_value.pulseIndex; // move forward a pulse while there is
}

return *this;
}

bool PulseIndexer::Iterator::operator==(const PulseIndexer::Iterator &other) const {
if (this->m_indexer != other.m_indexer)
return false;
else
return this->m_value.pulseIndex == other.m_value.pulseIndex;
}

bool PulseIndexer::Iterator::operator!=(const PulseIndexer::Iterator &other) const { return !(*this == other); }

} // namespace Mantid::DataHandling
50 changes: 42 additions & 8 deletions Framework/DataHandling/test/PulseIndexerTest.h
Expand Up @@ -6,9 +6,9 @@
// SPDX - License - Identifier: GPL - 3.0 +
#pragma once

#include <cxxtest/TestSuite.h>

#include "MantidDataHandling/PulseIndexer.h"
#include <cxxtest/TestSuite.h>
#include <iostream>

using Mantid::DataHandling::PulseIndexer;

Expand Down Expand Up @@ -62,6 +62,26 @@ class PulseIndexerTest : public CxxTest::TestSuite {
TS_ASSERT_EQUALS(indexer.getStartEventIndex(i), indexer.getStopEventIndex(i));
TS_ASSERT_EQUALS(indexer.getStopEventIndex(i), total_events);
}

// check the iterator
TS_ASSERT_DIFFERS(indexer.cbegin(), indexer.cend());

{ // explicit for loop
size_t count{0};
for (auto iter = indexer.cbegin(); iter != indexer.cend(); ++iter)
count++;
TS_ASSERT_EQUALS(count, indexer.getLastPulseIndex() - indexer.getFirstPulseIndex());
}

{ // range based for loop
size_t count{0};

for (const auto &iter : indexer) { // requires begin() and end()
(void)iter; // to quiet unused arg warning
count++;
}
TS_ASSERT_EQUALS(count, indexer.getLastPulseIndex() - indexer.getFirstPulseIndex());
}
}

std::shared_ptr<std::vector<uint64_t>> generate_nonConstant() {
Expand Down Expand Up @@ -167,7 +187,7 @@ class PulseIndexerTest : public CxxTest::TestSuite {
eventIndices->push_back(i);

constexpr size_t first_pulse_index{1};
const size_t last_pulse_index{eventIndices->size() - 1};
const size_t last_pulse_index{eventIndices->size() - 2};
const auto start_event_index = eventIndices->operator[](1);
const size_t total_events{eventIndices->back() - eventIndices->operator[](first_pulse_index) - 1};
std::vector<size_t> roi;
Expand Down Expand Up @@ -221,27 +241,41 @@ class PulseIndexerTest : public CxxTest::TestSuite {
PulseIndexer indexer(eventIndices, start_event_index, total_events, entry_name, roi);

TS_ASSERT_EQUALS(indexer.getFirstPulseIndex(), roi.front());
TS_ASSERT_EQUALS(indexer.getLastPulseIndex(), last_pulse_index);
TS_ASSERT_EQUALS(indexer.getLastPulseIndex(), last_pulse_index - 2); // roi gets rid of more

auto toEventIndex = [eventIndices, start_event_index](const size_t pulseIndex) {
return eventIndices->operator[](pulseIndex) - start_event_index;
};

const auto exp_start_event = eventIndices->operator[](indexer.getFirstPulseIndex()) - start_event_index;
const auto exp_total_event = total_events - start_event_index;
const auto exp_total_event = toEventIndex(4) - toEventIndex(2);

// check the individual event indices
assert_indices_equal(indexer, 0, exp_start_event, exp_start_event); // exclude before
assert_indices_equal(indexer, 1, exp_start_event, exp_start_event); // exclude before
assert_indices_equal(indexer, 2, toEventIndex(2), toEventIndex(3)); // include
assert_indices_equal(indexer, 3, toEventIndex(3), toEventIndex(4)); // include
assert_indices_equal(indexer, 4, toEventIndex(4), toEventIndex(4)); // exclude
assert_indices_equal(indexer, 5, toEventIndex(5), toEventIndex(5)); // exclude
assert_indices_equal(indexer, 6, toEventIndex(6), exp_total_event); // include
assert_indices_equal(indexer, 4, total_events, total_events); // exclude
assert_indices_equal(indexer, 5, total_events, total_events); // exclude
assert_indices_equal(indexer, 6, total_events, total_events); // exclude
assert_indices_equal(indexer, 7, total_events, total_events); // exclude due to number of events
assert_indices_equal(indexer, 8, total_events, total_events); // exclude after
assert_indices_equal(indexer, 9, total_events, total_events); // exclude after
assert_indices_equal(indexer, 10, total_events, total_events); // exclude out of range

// check the iterator
TS_ASSERT_DIFFERS(indexer.begin(), indexer.end());
TS_ASSERT_DIFFERS(indexer.cbegin(), indexer.cend());

// range based for loop
size_t num_steps{0};
size_t num_events{0};
for (const auto &iter : indexer) { // requires begin() and end()
num_events += (iter.eventIndexStop - iter.eventIndexStart);
num_steps++;
}
TS_ASSERT_EQUALS(num_events, exp_total_event);
TS_ASSERT_EQUALS(num_steps, 2); // calculated by hand
}
}
};

0 comments on commit ecf3f7a

Please sign in to comment.