
[wip] resnet batchnorm backward fusion spec #4370

Draft · wants to merge 36 commits into master
Conversation

chaosagent (Contributor)

small example for easy inspection for now

Qazalin (Collaborator) commented May 1, 2024

thanks - adding this to the scheduler roadmap!

# easy case: merge 4 reduces in backward into 1
# double reduce case: merge stat calculations from 2 to 1 (be careful of long reduces!)
# sum(x - \bar{x}): one kernel just calculates this, can be eliminated
# pre-expand fusion: is it fast? -2 kernels possible, 1 fw, 1 bw
Collaborator:

does this refer to E_2_16_64n1 + E_2048
(graph ref: https://tiny-tools-client.vercel.app/?id=f7b72a41bad14974970329924c89b2c0)?
#4235 could do this, but it won't, because <LB METAL (2, 16, 8, 8) float (<UnaryOps.CAST: 3>, None)> is forced_realize. I think it breaks the API if we fuse a forced_realize parent with its child.

Contributor Author:

I am referring to E2_16_64n1 (full forward with relu) and E2_16_64 (full backward through batchnorm). The first can be fused with the next conv, and the latter can be fused with the next backward conv. (E_2048 simulates the backward from the next layer, plus relu backward)

This test case does not have the convs (to focus on batchnorm), so that fusion cannot happen here. Will add more cases.
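
For context on the "merge 4 reduces in backward into 1" spec comment above, here is a generic numpy sketch of the per-channel reduces a batchnorm backward produces. This is not the PR's test code, and the exact count of separate reduce kernels depends on how autodiff expresses the backward:

```python
import numpy as np

# generic batchnorm-backward reduces (sketch only, not tinygrad code).
# x, dout: (N, C, H, W); every reduce below is over the same axes (0, 2, 3)
# and reads the same buffers, so they are candidates for one fused kernel.
def bn_backward_reduces(x, dout, eps=1e-5):
  axes = (0, 2, 3)
  mean = x.mean(axis=axes, keepdims=True)
  invstd = 1.0 / np.sqrt(x.var(axis=axes, keepdims=True) + eps)
  x_hat = (x - mean) * invstd
  sum_dout = dout.sum(axis=axes)                 # feeds grad of the shift (beta)
  sum_dout_xhat = (dout * x_hat).sum(axis=axes)  # feeds grad of the scale (gamma)
  # the input gradient reuses both sums, so the backward graph can end up with
  # several separate per-channel reduce kernels over the same data
  return sum_dout, sum_dout_xhat
```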


chaosagent (Contributor Author) commented May 2, 2024

Added a detailed behavior spec. The fusion decision for the parallel reduces should be straightforward and "free" performance-wise, but fusing conv(a + b) may be bad in some cases. We need a heuristic to decide when a buffer counts as a "big" buffer and when it counts as a "small" one.

The specs so far can remove 8 out of 14 extraneous memory passes in bn(conv2d).relu(), with an estimated time saving of 33ms on BS=256 resnet.

(Edited because I posted fake news)
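
For anyone reproducing the bn(conv2d).relu() case, the op pattern looks roughly like the sketch below. The shapes echo the (2, 16, 8, 8) LB mentioned earlier and are illustrative; the kernel counts above come from the spec, not from this snippet:

```python
from tinygrad import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d
from tinygrad.nn.optim import SGD

# illustrative bn(conv2d).relu() forward + backward; only the op pattern
# matters for counting extraneous memory passes in the schedule
Tensor.training = True
x = Tensor.randn(2, 16, 8, 8)
x.requires_grad = True
conv = Conv2d(16, 16, 3, padding=1, bias=False)
bn = BatchNorm2d(16)
opt = SGD([conv.weight, bn.weight, bn.bias], lr=0.01)  # marks params as requiring grad
bn(conv(x)).relu().sum().backward()
# the resulting schedule (forward outputs plus x.grad / conv.weight.grad /
# bn.weight.grad / bn.bias.grad) is what the fusion spec is trying to shrink
```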

chaosagent (Contributor Author) commented May 5, 2024

The scheduler change is a little tricky, since you need to make sure that each grouping is a contiguous sub-DAG. My current solution is to do the grouping while toposorting, which should work for the specific bn training case, but is it possible to make it cleaner?

Deferring contiguous reduces until you run out of nodes in the queue, then grouping them, would probably work.
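
A rough sketch of that idea on a generic DAG (not the PR's implementation; the node bookkeeping is simplified):

```python
from collections import deque

def group_reduces(nodes, children, in_degree, is_reduce):
  # Kahn-style toposort that holds reduce nodes back until nothing else is
  # ready, then emits all deferred reduces as a single group
  ready = deque(n for n in nodes if in_degree[n] == 0)
  deferred, order = [], []
  def release(n):
    for c in children[n]:
      in_degree[c] -= 1
      if in_degree[c] == 0: ready.append(c)
  while ready or deferred:
    if not ready:                       # out of other work: flush the reduces
      order.append(tuple(deferred))     # one candidate fusion group
      for r in deferred: release(r)
      deferred = []
    elif is_reduce(n := ready.popleft()):
      deferred.append(n)                # defer; don't release its children yet
    else:
      order.append((n,))
      release(n)
  return order
```

Each flushed group only contains reduces that were simultaneously ready, so scheduling them together keeps the grouping a valid contiguous sub-DAG.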

Qazalin (Collaborator) commented May 5, 2024

I need to think about the scheduler change a bit more, but in general we don't wanna do merge schedules; if there is grouping to be done, it should happen here: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/engine/schedule.py#L225-L228

new_arg = MemBuffer(new_lbs.index(old_lbs[ast.arg.idx]), ast.arg.dtype, ast.arg.st) if ast.op in [BufferOps.LOAD, BufferOps.STORE] else ast.arg
return LazyOp(ast.op, tuple(_replace_bufis(x, old_lbs, new_lbs) for x in ast.src), new_arg)

def _merge_prescheduled(prescheduled: List[_LBScheduleItem]):
Collaborator:

I've gone through this route in multioutput,

def _schedule_outputs(outs:List[_LBScheduleItem], reduce_for_op:Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:

I think you need to rebuild the entire AST.
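
For concreteness, "rebuild the entire AST" means rewriting every LOAD/STORE's MemBuffer index against the merged buffer list, in the spirit of the _replace_bufis snippet quoted above. This is a sketch assuming those internals; the real multi-output path presumably also needs to keep the output STOREs at the front of the merged buffer list:

```python
from tinygrad.ops import LazyOp, BufferOps, MemBuffer

def remap(ast: LazyOp, old_bufs, new_bufs) -> LazyOp:
  # recursively rebuild the AST, pointing every MemBuffer at the merged list
  src = tuple(remap(s, old_bufs, new_bufs) for s in ast.src)
  if ast.op in (BufferOps.LOAD, BufferOps.STORE):
    a = ast.arg
    return LazyOp(ast.op, src, MemBuffer(new_bufs.index(old_bufs[a.idx]), a.dtype, a.st))
  return LazyOp(ast.op, src, ast.arg)

def merge_two(ast_a, bufs_a, ast_b, bufs_b):
  merged = list(dict.fromkeys([*bufs_a, *bufs_b]))  # dedup buffers shared by both kernels
  return (remap(ast_a, bufs_a, merged), remap(ast_b, bufs_b, merged)), merged
```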

chaosagent (Contributor Author):

I implemented deferring contiguous reduces until you run out of nodes in the queue. It seems to work quite well, and it passed all the tests I had (very surprised this went so smoothly). It's a huge mess, though.

I am prototyping with merge_prescheduled because I need to toposort to find these fusion opportunities (I don't see a way to analyze the graph locally to find them), and I need shapetracker information to match (lazybuffer, st) read pairs, conveniently provided by preschedule.

The rules as implemented are a little in the style of a "performance heuristic", though, which is a little different from the other rules we have. Is it possible to move back to pure scheduling land?

Qazalin (Collaborator) commented May 6, 2024

I think all of your fusion targets are children of E_2048
(screenshot of the kernel graph)

https://tiny-tools-client.vercel.app/?id=3ef8c4a72b0c4999acca0dff9288b2fa

could traversing its local graph work?

chaosagent (Contributor Author) commented May 6, 2024

Some of them are also children of the forward pass. How can we tell if there is a path forward -> BN forward -> stuff -> fusion targets so that we don't fuse bn forward and backward?

The first attempt did toposort + local children. But if you don't have all inputs before E_2048 (with BN we are lucky), you have to get lucky with the toposort order (most of the tests will not pass).

# match by input + ST and two shapes? start with contiguous input only, check shapes (should determine reduces)

# what if same input + st but one is early and another is late?
check_schedule([x.sum(0, keepdim=True) + a, (a + b).sum()], 2)
Collaborator:

is this a real-world case?

Contributor Author:

could be... maybe if you have a bias weight and

out.sum(0) + bias -> next layer

(bias**2).sum() -> LARS

?
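
Spelled out with made-up shapes (hypothetical, just to make the early/late read of the same buffer concrete):

```python
from tinygrad import Tensor

out = Tensor.randn(128, 512)
bias = Tensor.randn(512)
next_layer_in = out.sum(0) + bias   # bias read "late": elementwise, after the reduce
lars_term = (bias ** 2).sum()       # bias read "early": inside another reduce
# same buffer and shapetracker for `bias` in both kernels, but on different
# sides of a reduce, which is the tricky part of the matching rule above
```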

check_schedule([sum1, (x + sum1).sum()], 2)
del sum1

# super tricky crossing dag case
Collaborator:

do you think fusing this is faster?
(screenshot of the kernel graph)

chaosagent (Contributor Author) commented May 6, 2024:

The (conservative) heuristic I am using is that this fusion should never add extra loads from bijective shapetrackers. If a shapetracker is bijective, then its size matches the full_shape of the kernel, and all non-bijective loads must be from smaller buffers (or buffer regions). In the normal case, the non-bijective "small" buffers come from expands and are very small compared to the bijective ones (here it's 1/16), so adding them won't hurt. Here, fusing the diagonal will save one memory pass over a big buffer.
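
A sketch of that rule with made-up helper names (not scheduler API; a real check would come from the ShapeTracker, not from shapes alone):

```python
from math import prod

def load_is_bijective(view_shape, buffer_size) -> bool:
  # "bijective": the view reads every element of its buffer exactly once, so
  # its size necessarily matches the kernel's full_shape (masks/strides ignored)
  return prod(view_shape) == buffer_size

def fusion_ok(added_loads) -> bool:
  # added_loads: (view_shape, buffer_size) pairs the candidate fusion would
  # introduce; reject it if any of them is a full-size (bijective) load
  return not any(load_is_bijective(vs, bs) for vs, bs in added_loads)
```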

Contributor Author:

In fact, for simple reduces like these from bijective shapetrackers, it should be fine to fuse many unrelated reduces. Simple reduces don't really need a lot of cache -- the cache really helps when you have expands like (1, a) * (b, 1), since you can do an nm-sized tile with only n + m loads.
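
Quick numbers for that locality argument (nothing tinygrad-specific):

```python
# an n x m output tile of a (1, a) * (b, 1) expand touches only n + m distinct
# input elements, so locals/cache give real reuse; a simple reduce reads each
# input element once, so fusing several of them shares the loop, not the cache
n, m = 32, 32
print(f"tile outputs: {n*m}, inputs loaded: {n+m}, outputs per load: {n*m/(n+m):.0f}")
```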

Contributor Author:

I think this may even be a real-world case -- consider x and y to be the forward outputs of different layers.

Qazalin (Collaborator) commented May 6, 2024

forward -> BN forward -> stuff -> fusion targets

If you fuse those targets, doesn't the cache fill up with a bunch of the "stuff" bufs? We wanna fuse if they're sharing parents.

chaosagent (Contributor Author) commented May 6, 2024

If you fuse those targets, doesn't the cache fill up with a bunch of the "stuff" bufs? We wanna fuse if they're sharing parents.

We need to allow small "stuff"s (the bn backward takes some inputs from bn forward); see the argument for the bijective heuristic above.

chaosagent (Contributor Author) commented May 6, 2024

Hm, I think one of these kernels has a superset of the "stuffs" across the rest of the fusion targets. I think that makes it safe to not check the "stuffs" 🤔

Actually no, it doesn't, since one of the "stuffs" that only the superset kernel has could be a descendant of the rest of the fusion targets.
