Preorder reduces #4409
base: master
Conversation
oh, i overlooked this PR. this is pretty much what i had in mind, nice!
  if x.op in ReduceOps and not do_reduce:
-   assert offs is None, "not available if we aren't doing reduce"
-   return acc
+   return [reduced_ops[x][i] for i in offs] if offs else reduced_ops[x]
i think offs just removes unrolled shapes from iteration, so just returning reduced_ops[x] here is correct.
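(for intuition, a toy illustration of the two return paths; the values here are made up stand-ins, not real UOps:)
reduced_ops_x = ["u0", "u1", "u2", "u3"]   # one stand-in entry per accumulator slot
offs = [0, 2]                              # e.g. what acc_offsets() might yield for an upcast axis
[reduced_ops_x[i] for i in offs]           # -> ['u0', 'u2'], the slots for this unrolled iteration
reduced_ops_x                              # -> all slots, unchanged, when offs is None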
maybe I goofed something elsewhere and this just covers it up, but this was so that, if there's an intermediate calculation between two reduceops, it gets brought back to the unrolled shape
ex: an unrolled standard deviation looks like this without the offs:
__kernel void r_4_32_8n4(__global float* data0, const __global float* data1) {
int gidx0 = get_group_id(0); /* 4 */
int lidx1 = get_local_id(0); /* 32 */
float acc0 = 0.0f;
int alu0 = ((gidx0*256)+(lidx1*8));
float val0 = data1[alu0];
float val1 = data1[alu0+1];
float val2 = data1[alu0+2];
float val3 = data1[alu0+3];
float val4 = data1[alu0+4];
float val5 = data1[alu0+5];
float val6 = data1[alu0+6];
float val7 = data1[alu0+7];
float acc1 = 0.0f;
data0[(gidx0*32)+lidx1] = (0.015625f*((val0-((val7+val6+val5+val4+val3+val2+val1+val0+acc0)*0.015625f))+acc1));
}
when what we want is this:
__kernel void r_4_32_8n4(__global float* data0, const __global float* data1) {
int gidx0 = get_group_id(0); /* 4 */
int lidx1 = get_local_id(0); /* 32 */
float acc0 = 0.0f;
int alu0 = ((gidx0*256)+(lidx1*8));
float val0 = data1[alu0];
float val1 = data1[alu0+1];
float val2 = data1[alu0+2];
float val3 = data1[alu0+3];
float val4 = data1[alu0+4];
float val5 = data1[alu0+5];
float val6 = data1[alu0+6];
float val7 = data1[alu0+7];
float acc1 = 0.0f;
float alu1 = ((val7+val6+val5+val4+val3+val2+val1+val0+acc0)*0.015625f);
data0[(gidx0*32)+lidx1] = (0.015625f*((val7-alu1)+(val6-alu1)+(val5-alu1)+(val4-alu1)+(val3-alu1)+(val2-alu1)+(val1-alu1)+(val0-alu1)+acc1));
}
oh, i see. makes sense!
  stores = self.global_store(-1, [NumNode(0)]*len(self.sts[-1].shape), acc)
  endif = self.uops.add(UOps.ENDIF, None, (barrier,), cachable=False)
  barrier = self.uops.add(UOps.BARRIER, None, (endif,), cachable=False)
  acc = self.global_load(-1, [NumNode(0)]*len(self.sts[-1].shape), barrier=barrier)
load local reduction back into every thread
- if x.op is UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0], \
-   insert_before=insert_before) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers, insert_before=insert_before)]
+ if x.op is UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \
+   for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers, reduced_ops)]
  if x.op in ReduceOps and not do_reduce:
don't we want to do this read also if do_reduce and x is not the reduce we want to do here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure exactly what you mean, but it seemed like the only time do_reduce was set to true was when ast_parse was directly fed the reduceop:
linearizer.py:256 self.ast_parse(reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
linearizer.py:298 self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # noqa: E501
I agree, we do want to do that read, but in that case we'd have to pass the reduceop we want to do to ast_parse as well. I suppose that's not that hard: we just add another parameter, or we could make do_reduce an Optional[LazyOp] and pass it there.
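something like this, as a sketch (the reduce_for name is made up and the signature is approximated, not the actual diff):
from typing import Optional
from tinygrad.ops import LazyOp, ReduceOps  # assuming the usual imports

def ast_parse(self, x, acc, offs, loaded_buffers, reduced_ops,
              reduce_for: Optional[LazyOp] = None, loop_ctx=(), cache=None):
  # sketch: any ReduceOp that isn't the one we're executing right now
  # reads its already-computed value out of reduced_ops
  if x.op in ReduceOps and x is not reduce_for:
    return [reduced_ops[x][i] for i in offs] if offs else reduced_ops[x]
  ...  # otherwise fall through to the normal reduce handling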
it seemed like the only time do_reduce was set to true was when ast_parse was directly fed the reduceop
on the second-pass reduce, ast_parse will recurse down to the first-pass reduce's reduceop, right?
yeah, but do_reduce will get turned off when it recurses through:
values = [self.ast_parse(v, acc, offs, loaded_buffers, reduced_ops, loop_ctx=loop_ctx, cache=cache) for v in x.src]
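(the signature here is reconstructed from the call sites, not copied from the diff:)
def ast_parse(self, x, acc, offs, loaded_buffers, reduced_ops,
              do_reduce=False, loop_ctx=(), cache=None):
  # do_reduce defaults to False and isn't forwarded in the recursive
  # call above, so nested reduceops always take the not-do_reduce branch
  ...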
also re: #4419: reduced_ops is my way of accomplishing what I think you do with accs. I think replacing accs is prettier, esp if it helps with parallel reduceops.
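for reference, roughly how I picture reduced_ops (the types are assumed, not from the diff):
from typing import Dict, List

# map each already-linearized reduceop to the UOps holding its result,
# so later expressions substitute the value instead of re-walking the reduce
reduced_ops: Dict[LazyOp, List[UOp]] = {}
# populated after each first-pass reduce, something like:
#   reduced_ops[reduceop] = self.ast_parse(reduceop, acc, self.acc_offsets(...), loaded_buffers, do_reduce=True)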
@@ -349,7 +349,8 @@ def optimize_loops(self):
   # graph helper functions
   @functools.lru_cache(None)
   def get_recursive_parents(x:UOp, with_phi=False) -> Set[UOp]:
-    return set.union(set(x.vin), *[get_recursive_parents(p, with_phi) for p in x.vin], set(acc_scope[x]) if with_phi else set())
+    return set.union(set(x.vin), *[get_recursive_parents(p, with_phi) for p in x.vin if p.uop is not UOps.BARRIER], \
+      set(acc_scope[x]) if with_phi else set())
don't propagate get_recursive_parents through a barrier
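a self-contained toy of the cutoff (ToyUOp and its fields are simplified stand-ins, purely illustrative):
import functools
from dataclasses import dataclass
from typing import Set, Tuple

@dataclass(frozen=True)
class ToyUOp:
  uop: str                            # e.g. "LOAD", "BARRIER", "STORE"
  vin: Tuple["ToyUOp", ...] = ()

@functools.lru_cache(None)
def get_recursive_parents(x: ToyUOp) -> Set[ToyUOp]:
  # include every direct parent, but don't recurse through a BARRIER,
  # so uops behind the barrier stay out of the parent set
  return set.union(set(x.vin), *[get_recursive_parents(p) for p in x.vin if p.uop != "BARRIER"])

store = ToyUOp("STORE")
barrier = ToyUOp("BARRIER", (store,))
load = ToyUOp("LOAD", (barrier,))
assert get_recursive_parents(load) == {barrier}  # STORE stays hidden behind the barrier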
@Qazalin turns out preordering them was p easy
drafting this PR because it includes the linearizer changes (so I could test that it works for asts with multiple reduceops). I can break it down into PRs for the ordering code & linearizer changes if that would be easier