Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split UnaryOps.CAST into CAST and BITCAST #4487

Merged
merged 7 commits into from May 15, 2024

Conversation

wpmed92
Copy link
Contributor

@wpmed92 wpmed92 commented May 9, 2024

This PR introduces UnaryOps.BITCAST, and it is used instead of UnaryOps.CAST and bitcast=true boolean flag combination.

a = Tensor([1,2,3], dtype=dtypes.float).bitcast(dtypes.int).realize()

before

  0 ━┳ STORE MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))
  1  ┗━┳ CAST (dtypes.int, True)
  2    ┗━━ LOAD MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))
((LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=(dtypes.int, True)),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), [Opt(op=OptOps.LOCAL, axis=0, amt=3)])

after

 0 ━┳ STORE MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))
 1  ┗━┳ BITCAST dtypes.int
 2    ┗━━ LOAD MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))
((LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.BITCAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.int),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), [Opt(op=OptOps.LOCAL, axis=0, amt=3)])

@wpmed92 wpmed92 changed the title Separate cast and bitcast Separate UnaryOps.CAST into CAST and BITCAST May 9, 2024
@wpmed92 wpmed92 changed the title Separate UnaryOps.CAST into CAST and BITCAST Split UnaryOps.CAST into CAST and BITCAST May 9, 2024
@geohot geohot added the bounty locked Bounty is locked to someone label May 9, 2024
@geohot
Copy link
Collaborator

geohot commented May 9, 2024

Good to see you back!

@geohot
Copy link
Collaborator

geohot commented May 9, 2024

arg of cast should be changed to only be the dtype, we don't need the True/False anymore since CAST/BITCAST are different ops

@wpmed92 wpmed92 force-pushed the bitcast-unary-op branch 3 times, most recently from 99157dd to fbb09ae Compare May 11, 2024 07:35
@wpmed92 wpmed92 force-pushed the bitcast-unary-op branch 2 times, most recently from 6b3b035 to b3ddb4b Compare May 12, 2024 12:30
@wpmed92
Copy link
Contributor Author

wpmed92 commented May 12, 2024

@Qazalin @geohot

  • cast arg is now just the dtype (it's not a tuple anymore)
  • no bitcast on ImageDType
  • regenerated dataset so that sops.gz has the new UnaryOps.CAST and UnaryOps.BITCAST in the ASTs (for this I also had to update the UNROLL action in search, because I got a wrong action for UNROLL, axis=4 for the new dataset)

@wpmed92 wpmed92 requested a review from Qazalin May 12, 2024 12:40
tinygrad/codegen/kernel.py Outdated Show resolved Hide resolved
tinygrad/lazy.py Outdated Show resolved Hide resolved
tinygrad/codegen/kernel.py Outdated Show resolved Hide resolved
@wpmed92 wpmed92 force-pushed the bitcast-unary-op branch 3 times, most recently from 71580b9 to ce5a55f Compare May 13, 2024 07:08
tinygrad/codegen/kernel.py Outdated Show resolved Hide resolved
@Qazalin
Copy link
Collaborator

Qazalin commented May 13, 2024

nice - new graph:
image

I'm not sure about the wrong UNROLL action though - What was the error?

@wpmed92
Copy link
Contributor Author

wpmed92 commented May 13, 2024

I was getting this for PYTHONPATH="." GPU=1 python3 extra/optimization/get_action_space.py:

Screenshot 2024-05-13 at 09 32 48

Copy link
Contributor

Changes

Name                              Lines    Diff    Tokens/Line    Diff
------------------------------  -------  ------  -------------  ------
tinygrad/codegen/kernel.py          460      +0           18.4    -0.0
tinygrad/codegen/linearizer.py      337      +0           19.3    +0.0
tinygrad/engine/schedule.py         243      +0           14.0    -0.0
tinygrad/lazy.py                    160      +1           18.9    -0.2
tinygrad/ops.py                     108      +1           17.4    +0.2


total lines changes: +2

@Qazalin
Copy link
Collaborator

Qazalin commented May 13, 2024

did you try bumping the DB version for the actions issue? you might just need to revalidate the sqlite db instead of changing action space script.

@wpmed92
Copy link
Contributor Author

wpmed92 commented May 13, 2024

Still wrong after db bump, only seems to be solved by bumping axis in unroll.

@chenyuxyz
Copy link
Collaborator

the error is saying the handcoded opt had used a action that's not in the search action lists. can you print the kernel that caused this?

@wpmed92
Copy link
Contributor Author

wpmed92 commented May 15, 2024

@chenyuxyz

Screenshot 2024-05-15 at 11 03 16

(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=7, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(5, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=8, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=9, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=10, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 3, 4, 5, 6, 1, 2, 5, 6, 5, 3, 4), strides=(1440, 480, 120, 24, 4, 0, 0, 0, 0, 0, 0, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0, 1, 2, 3, 4)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 2, 5, 6, 5, 3, 4), strides=(0, 0, 0, 0, 0, 0, 1800, 360, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)))),)

@chenyuxyz
Copy link
Collaborator

it triggered a new tc, might be related to #4427. @flammit fyi

@chenyuxyz chenyuxyz merged commit 662bca8 into tinygrad:master May 15, 2024
19 checks passed
@chenyuxyz
Copy link
Collaborator

congrats! george@tinygrad.org to claim

@wpmed92 wpmed92 deleted the bitcast-unary-op branch May 15, 2024 17:50
dimaheve pushed a commit to dimaheve/tinygrad that referenced this pull request May 15, 2024
* Separate cast and bitcast

* Fix lint

* No more arg[0]

* Revert "No more arg[0]"

This reverts commit dee6911.

* CAST/BITCAST arg is the dtype only, no more tuple

* No image bitcast, regenerate dataset

* Small fixes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bounty locked Bounty is locked to someone
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants