// RangeSumQuery.java
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.function.BiFunction;
// https://leetcode.com/problems/range-sum-query-mutable
class NumArray {
SegmentTree<Integer> tree;
public NumArray(int[] nums) {
tree = new SegmentTree(
Arrays.stream(nums).boxed().toArray(),
SegmentTree.SegmentTreeFunction.SUM_QUERY,
Integer.class
);
}
public void update(int index, int val) {
tree.update(index, val);
}
public int sumRange(int left, int right) {
return tree.query(left, right);
}
/**
* Class that implements segment tree.
*/
class SegmentTree<T> {
/**
* Original array
*/
private final T[] originalArray;
// Size of the original array
private final int n;
// Segment tree
// Formula: Index of left child = 2 * i + 1, index of right child = 2 * i + 2.
private final T[] segmentTree;
// Segment Operation
private final SegmentTreeFunction<T> segmentTreeFunction;
/**
* Constructor.
*
* @param arr Original array.
* @param segmentOperation The operation on any 2 numbers that we want to compute and store it in our segment tree.
*/
public SegmentTree(T[] arr, SegmentTreeFunction<T> segmentTreeFunction, Class<T> clazz) {
this.originalArray = arr;
this.n = arr.length;
this.segmentTreeFunction = segmentTreeFunction;
this.segmentTree = (T[]) Array.newInstance(clazz, 4 * n);
// Build the segment tree
build(0, 0, n - 1);
// System.out.println("Segment tree = " + Arrays.toString(segmentTree));
}
/**
* Builds the segment tree.
*
* @param index Index of the segment tree whose value we are computing.
* @param low Lower bound (inclusive) of the range of the original array whose value we are computing.
* @param high Higher bound (inclusive) of the range of the original array whose value we are computing.
*/
private void build(int index, int low, int high) {
if (low == high) {
segmentTree[index] = originalArray[low];
return;
}
int mid = (low + high)/2;
// Left
build(2 * index + 1, low, mid);
// Right
build(2 * index + 2, mid + 1, high);
segmentTree[index] = segmentTreeFunction.compute(segmentTree[2 * index + 1], segmentTree[2 * index + 2]);
}
/**
* Queries the value for the given range [l, r];
*
* @param l Lower bound of range (inclusive).
* @param r Higher bound of range (inclusive).
* @return Result of the query.
*/
T query(int l, int r) {
return _query(l, r, 0, n - 1, 0);
}
private T _query(int l, int r, int low, int high, int index) {
// If the range that we are checking is completely within the range that is queried for.
if (low >= l && high <= r) {
return segmentTree[index];
}
// Does not lie at all
if (low > r || high < l) {
return segmentTreeFunction.fallbackValue;
} else {
/*if (low == high) {
return segmentTree[index];
}*/
// Overlaps
int mid = (low + high)/2;
return segmentTreeFunction.compute(_query(l, r, low, mid, 2 * index + 1), _query(l, r, mid + 1, high, 2 * index + 2));
}
}
/**
* Updates originalArra[pos] with the newVal.
*
* @param pos Position of original array to update.
* @param newVal New value to update the array with.
* @return True if update was successful, false otherwise.
*/
boolean update(int pos, T newVal) {
if (pos < 0 || pos >= n) {
return false;
}
_update(0, pos, 0, n - 1, newVal);
// System.out.println("After update, segment tree = " + Arrays.toString(segmentTree));
return true;
}
void _update(int index, int pos, int low, int high, T newVal) {
if (low == pos && high == pos) {
segmentTree[index] = newVal;
return;
}
if (high < pos || low > pos) {
// Out of range. Do nothing.
return;
}
int mid = (low + high)/2;
// Left
_update(2 * index + 1, pos, low, mid, newVal);
// Right
_update(2 * index + 2, pos, mid + 1, high, newVal);
segmentTree[index] = segmentTreeFunction.compute(segmentTree[2 * index + 1], segmentTree[2 * index + 2]);
}
/**
* Function that we are trying to perform with a segment tree.
*/
static class SegmentTreeFunction<T> {
// Segment Operation
private final BiFunction<T, T, T> segmentOperation;
// Fallback value when a range query cannot be answered
final T fallbackValue;
/**
* Constructor.
*
* @param segmentOperation Function operation.
* @param fallbackValue Fallback value.
*/
public SegmentTreeFunction(BiFunction<T, T, T> segmentOperation, T fallbackValue) {
this.fallbackValue = fallbackValue;
this.segmentOperation = segmentOperation;
}
/**
* Computes the function of the segment tree.
*
* @param val1 1st operand.
* @param val2 2nd operand.
* @return Computed value.
*/
T compute(T val1, T val2) {
return segmentOperation.apply(val1, val2);
}
// Commonly used segment tree operations
public static final SegmentTreeFunction<Integer> MAX_QUERY = new SegmentTreeFunction<Integer>((a, b) -> Math.max(a, b), Integer.MIN_VALUE);
public static final SegmentTreeFunction<Integer> MIN_QUERY = new SegmentTreeFunction<Integer>((a, b) -> Math.min(a, b), Integer.MAX_VALUE);
public static final SegmentTreeFunction<Integer> SUM_QUERY = new SegmentTreeFunction<Integer>((a, b) -> a + b, 0);
}
}
}
/**
* Your NumArray object will be instantiated and called as such:
* NumArray obj = new NumArray(nums);
* obj.update(index,val);
* int param_2 = obj.sumRange(left,right);
*/