Auto Difference with Matrix Calculation #8498

Open
LarkLeeOnePiece opened this issue Mar 27, 2024 · 0 comments
Labels: question (Question on using Taichi)

@LarkLeeOnePiece

I want to test automatic differentiation (`ti.ad.Tape`) on a computation that involves a matrix operation.

Here is my code:
```python
import taichi as ti
import math

ti.init(arch=ti.gpu, debug=True)

N = 10
Dim3 = 3
Dim2 = 2
Dim1 = 1
positions = ti.Vector.field(Dim2, dtype=ti.f32, shape=(N,), needs_grad=True)  # define a 2D vector field
pos2Ds = ti.Vector.field(Dim1, dtype=ti.f32, shape=(N,), needs_grad=True)
realpos = ti.field(dtype=ti.f32, shape=(N,), needs_grad=True)
L = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

theta = math.pi / 4  # rotate 45 degrees
rotation_matrix = ti.Matrix([[ti.cos(theta), -ti.sin(theta)], [ti.sin(theta), ti.cos(theta)]])  # define a rotation matrix


@ti.kernel
def init():
    for i in ti.grouped(positions):
        positions[i] = [1.0, 1.0]


@ti.kernel
def transform():
    for i in positions:
        pos2Ds[i] = (rotation_matrix @ positions[i]).y


@ti.kernel
def comp_loss():
    for i in ti.grouped(pos2Ds):
        L[None] += (realpos[i] - pos2Ds[i].x)


init()
with ti.ad.Tape(loss=L, validation=True):
    transform()
    comp_loss(
    )
print(L[None])
print(positions.grad)
```
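For reference, here is a plain-Python sketch (no Taichi involved) of the values the tape should produce for this example. It mirrors `transform` and `comp_loss` with ordinary lists and checks the gradient with central finite differences; `loss` and `numerical_grad` are hypothetical helper names, not Taichi APIs. Since `realpos` is never written it stays zero, and because the loss is linear in the positions, every `positions[i]` should receive the gradient `(-sin θ, -cos θ)`, i.e. the negated second row of `rotation_matrix`.

```python
import math

N = 10
theta = math.pi / 4  # rotate 45 degrees, as in the Taichi code
# Same rotation matrix as the Taichi version, as nested Python lists.
R = [[math.cos(theta), -math.sin(theta)],
     [math.sin(theta), math.cos(theta)]]


def loss(positions, realpos):
    """L = sum_i (realpos[i] - (R @ positions[i]).y), mirroring comp_loss."""
    total = 0.0
    for (px, py), r in zip(positions, realpos):
        y = R[1][0] * px + R[1][1] * py  # y component of the rotated point
        total += r - y
    return total


def numerical_grad(positions, realpos, eps=1e-6):
    """Central finite differences of loss w.r.t. every positions[i][k]."""
    grad = []
    for i in range(len(positions)):
        row = []
        for k in range(2):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][k] += eps
            minus[i][k] -= eps
            row.append((loss(plus, realpos) - loss(minus, realpos)) / (2 * eps))
        grad.append(row)
    return grad


positions = [[1.0, 1.0] for _ in range(N)]  # what init() writes
realpos = [0.0] * N  # realpos is never written, so it stays zero
print(loss(positions, realpos))  # about -14.142 (= -10 * sqrt(2))
print(numerical_grad(positions, realpos)[0])  # about [-0.7071, -0.7071]
```

If `positions.grad` from `ti.ad.Tape` matches these numbers once the kernels compile, the gradient through the matrix product is being computed correctly.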

But I get the following error:

```
RuntimeError Traceback (most recent call last)
Cell In[1], line 36
     33 transform()
     34 # Kernel invocations in this scope will later contribute to partial derivatives of
     35 # U with respect to input variables such as x.
---> 36 comp_loss(
     37 ) # The tape will automatically compute dU/dx and save the results in x.grad
     38 print(L[None])
     39 print(positions.grad)

File c:\Users\LID0E\AppData\Local\miniconda3\envs\Gaussin-Taichi\lib\site-packages\taichi\lang\kernel_impl.py:1103, in _kernel_impl.<locals>.wrapped(*args, **kwargs)
   1100 @functools.wraps(_func)
   1101 def wrapped(*args, **kwargs):
   1102 try:
-> 1103 return primal(*args, **kwargs)
   1104 except (TaichiCompilationError, TaichiRuntimeError) as e:
   1105 if impl.get_runtime().print_full_traceback:

File c:\Users\LID0E\AppData\Local\miniconda3\envs\Gaussin-Taichi\lib\site-packages\taichi\lang\shell.py:27, in _shell_pop_print.<locals>.new_call(*args, **kwargs)
     25 @functools.wraps(old_call)
     26 def new_call(*args, **kwargs):
---> 27 ret = old_call(*args, **kwargs)
     28 # print's in kernel won't take effect until ti.sync(), discussion:
     29 # #1303 (comment)
     30 print(_ti_core.pop_python_print_buffer(), end="")

File c:\Users\LID0E\AppData\Local\miniconda3\envs\Gaussin-Taichi\lib\site-packages\taichi\lang\kernel_impl.py:1035, in Kernel.call(self, *args, **kwargs)
   1033 key = self.ensure_compiled(*args)
   1034 kernel_cpp = self.compiled_kernels[key]
-> 1035 return self.launch_kernel(kernel_cpp, *args)

File c:\Users\LID0E\AppData\Local\miniconda3\envs\Gaussin-Taichi\lib\site-packages\taichi\lang\kernel_impl.py:966, in Kernel.launch_kernel(self, t_kernel, *args)
    964 if impl.get_runtime().print_full_traceback:
    965 raise e
--> 966 raise e from None
    968 ret = None
    969 ret_dt = self.return_type

File c:\Users\LID0E\AppData\Local\miniconda3\envs\Gaussin-Taichi\lib\site-packages\taichi\lang\kernel_impl.py:959, in Kernel.launch_kernel(self, t_kernel, *args)
    957 prog = impl.get_runtime().prog
    958 # Compile kernel (& Online Cache & Offline Cache)
--> 959 compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
    960 # Launch kernel
    961 prog.launch_kernel(compiled_kernel_data, launch_ctx)

RuntimeError: [taichi/ir/ir.h:taichi::lang::IRNode::as@248] Assertion failure: is<T>()
```

How can I fix this problem?
