Skip to content

Commit

Permalink
openpilot fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed Mar 6, 2023
1 parent 4b9bc16 commit d8dda2a
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 55 deletions.
48 changes: 0 additions & 48 deletions extra/thneed.py
Expand Up @@ -290,51 +290,3 @@ def run(self):
print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")
return total_runtime/1e9
return et

def optimize_local_workgroup(self):
MAX_WORKGROUP = CL.cl_ctx.devices[0].max_work_group_size
local_cl_cache = []
for prg, args in self.cl_cache:
potential_locals = [tuple(args[1])] if args[1] is not None else []
runtimes = []
args = list(args)

# NOTE: if args[1] is not None, it may use local variables and you shouldn't change this
if args[1] is None and len(args[0]) == 1:
for l1 in [args[0][0], 1, 4, 16, MAX_WORKGROUP//4, MAX_WORKGROUP]:
potential_locals.append((l1,))

if args[1] is None and len(args[0]) == 2:
for l2 in [1, 4, 16, MAX_WORKGROUP//4, MAX_WORKGROUP]:
potential_locals.append((min(MAX_WORKGROUP, args[0][0]), l2))

if args[1] is None and len(args[0]) == 3:
for l2 in [16,args[0][1],MAX_WORKGROUP]:
for l3 in [4,16,args[0][2],MAX_WORKGROUP]:
for l1 in [max(1, MAX_WORKGROUP//(l2*l3)), args[0][0], 4, 16, MAX_WORKGROUP]:
if l1 > args[0][0] or l2 > args[0][1] or l3 > args[0][2]: continue
potential_locals.append((l1, l2, l3))

for local_args in potential_locals:
if prod(local_args) > MAX_WORKGROUP: continue
args[1] = local_args
# 3 runs just in case
for i in range(3):
try:
e = prg.clprg(CL.cl_queue, *args)
except (cl.LogicError, cl.RuntimeError):
# INVALID_WORK_GROUP_SIZE
continue
CL.cl_queue.finish()
runtime = e.profile.end - e.profile.start
#print(runtime, args[0], args[1])
runtimes.append((runtime, local_args))

if len(runtimes) > 0:
args[1] = sorted(runtimes)[0][1]
else:
args[1] = None
print("couldn't optimize", args[0])

local_cl_cache.append((prg, args))
self.cl_cache = local_cl_cache
3 changes: 0 additions & 3 deletions openpilot/compile.py
Expand Up @@ -79,9 +79,6 @@ def compile(dat, output_fn):
from extra.thneed import Thneed
t = Thneed(cl_cache, {k:v._cl for k,v in input_rawbuffers.items()})

if getenv("OPTWG", 0):
t.optimize_local_workgroup()

# save thneed (before run)
t.save(output_fn)

Expand Down
3 changes: 1 addition & 2 deletions openpilot/go.sh
@@ -1,3 +1,2 @@
#!/bin/bash
FLOAT16=1 DEBUGCL=1 NATIVE_EXPLOG=1 VALIDHACKS=1 OPTWG=1 IMAGE=2 GPU=1 CLCACHE=0 python3 openpilot/compile.py

FLOAT16=1 DEBUGCL=1 NATIVE_EXPLOG=1 VALIDHACKS=1 OPTLOCAL=1 IMAGE=2 GPU=1 ENABLE_METHOD_CACHE=1 python3 openpilot/compile.py
6 changes: 4 additions & 2 deletions tinygrad/image.py
@@ -1,4 +1,5 @@
from tinygrad.helpers import IMAGE
from tinygrad.lazy import get_single_root

def image_conv2d_decorator(normal_conv):
if IMAGE == 0: return normal_conv
Expand Down Expand Up @@ -32,8 +33,9 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3).reshape(cout//4, H*cin//4*W*4, 4)
else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1).reshape(cout//4, H*cin//4*W*4, 4)

# contiguous creates the image, and early realize static weights (TODO: don't always realize)
x, w = x.contiguous(), w.contiguous().realize()
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
x, w = x.contiguous(), w.contiguous()
if get_single_root(w.lazydata).realized: w.realize()

# expand out
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
Expand Down

0 comments on commit d8dda2a

Please sign in to comment.