Incorrect code generated when contraction index is not innermost #555

Open
smr97 opened this issue Sep 7, 2023 · 0 comments


smr97 commented Sep 7, 2023

Hi, I am trying out different index orders for the following contraction:

T1[p, c, q] = start[p, q, r] * C3[c, r]

where T1 and start are sparse 3-D tensors and C3 is a sparse matrix. For the following index ordering, I suspect that the generated code (the compute kernel) is incorrect:

Expression 1: T1[q, c, p] = start[q, r, p] * C3[r, c]
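As a point of reference, here is a minimal C sketch of what Expression 1 should compute. It assumes, purely for illustration, that every operand is a dense row-major array; the function name reference_expr1 and the extents Q, R, P, C are hypothetical. Even with the same q, r, c, p loop order as the TACO kernel, each output entry T1[q, c, p] is zeroed once and then receives one contribution per value of r.

```c
#include <assert.h>
#include <stddef.h>

/* Dense reference for Expression 1:
 *   T1[q, c, p] = sum over r of start[q, r, p] * C3[r, c]
 * ASSUMPTION (illustration only): all operands are dense, row-major arrays
 * with hypothetical extents start: Q x R x P, C3: R x C, T1: Q x C x P. */
void reference_expr1(int Q, int R, int P, int C,
                     const double *start,
                     const double *c3,
                     double *t1)
{
    assert(Q >= 0 && R >= 0 && P >= 0 && C >= 0);

    /* Zero every output entry exactly once. */
    for (size_t i = 0; i < (size_t)Q * C * P; i++)
        t1[i] = 0.0;

    /* Same loop order as the generated kernel: q, r, c, p.
     * All values of r land on the same slot for a given (q, c, p). */
    for (int q = 0; q < Q; q++)
        for (int r = 0; r < R; r++)
            for (int c = 0; c < C; c++)
                for (int p = 0; p < P; p++)
                    t1[((size_t)q * C + c) * P + p] +=
                        start[((size_t)q * R + r) * P + p] * c3[(size_t)r * C + c];
}
```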

Here is the kernel I get from TACO:

 2 int compute(taco_tensor_t *t1, taco_tensor_t *start, taco_tensor_t *c3) {
 3   int t11_dimension = (int)(t1->dimensions[0]);
 4   double* restrict t1_vals = (double*)(t1->vals);
 5   int start1_dimension = (int)(start->dimensions[0]);
 6   int* restrict start2_pos = (int*)(start->indices[1][0]);
 7   int* restrict start2_crd = (int*)(start->indices[1][1]);
 8   int* restrict start3_pos = (int*)(start->indices[2][0]);
 9   int* restrict start3_crd = (int*)(start->indices[2][1]);
10   double* restrict start_vals = (double*)(start->vals);
11   int c31_dimension = (int)(c3->dimensions[0]);
12   int* restrict c32_pos = (int*)(c3->indices[1][0]);
13   int* restrict c32_crd = (int*)(c3->indices[1][1]);
14   double* restrict c3_vals = (double*)(c3->vals);
15
16   int32_t pt1 = 0;
17
18   for (int32_t q = 0; q < start1_dimension; q++) {
19     for (int32_t rstart = start2_pos[q]; rstart < start2_pos[(q + 1)]; rstart++) {
20       int32_t r = start2_crd[rstart];
21       for (int32_t cc3 = c32_pos[r]; cc3 < c32_pos[(r + 1)]; cc3++) {
22         for (int32_t pstart = start3_pos[rstart]; pstart < start3_pos[(rstart + 1)]; pstart++) {
23           t1_vals[pt1] = 0.0;
24           t1_vals[pt1] = t1_vals[pt1] + start_vals[pstart] * c3_vals[cc3];
25           pt1++;
26         }
27       }
28     }
29   }
30   return 0;
31 }

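It looks like t1_vals is updated at a different position for every value of r (bound on line 20): pt1 is incremented on every iteration of the innermost loop, so each (q, r, c, p) combination gets its own slot, and line 23 resets that slot to 0.0 before the single update on line 24.
Since the contraction is over r, the contributions from all values of r should instead accumulate into the same position in t1_vals, namely the one for (q, c, p).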
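For comparison, here is a sketch of a kernel that keeps the generated loop order (q, then r, then c, then p) and the same pos/crd arrays, but accumulates into the slot for (q, c, p). It assumes, hypothetically, that T1 is stored densely with extents Q x C x P and that plain arrays replace taco_tensor_t; a sparse T1 would additionally need output assembly (for example through a dense workspace), which this sketch does not attempt. The function name expr1_dense_out is made up for illustration.

```c
#include <assert.h>
#include <stddef.h>

/* Sketch for T1[q, c, p] = sum_r start[q, r, p] * C3[r, c], q -> r -> c -> p.
 * start: dense q, compressed r (start2_pos/crd), compressed p (start3_pos/crd)
 * C3:    dense r, compressed c (c32_pos/crd)
 * T1:    ASSUMED dense Q x C x P (hypothetical; not TACO's sparse output). */
void expr1_dense_out(int Q, int C, int P,
                     const int *start2_pos, const int *start2_crd,
                     const int *start3_pos, const int *start3_crd,
                     const double *start_vals,
                     const int *c32_pos, const int *c32_crd,
                     const double *c3_vals,
                     double *t1_vals)
{
    /* Zero the output once, before any contribution arrives. */
    for (size_t i = 0; i < (size_t)Q * C * P; i++)
        t1_vals[i] = 0.0;

    for (int q = 0; q < Q; q++) {
        for (int rs = start2_pos[q]; rs < start2_pos[q + 1]; rs++) {
            int r = start2_crd[rs];
            for (int cc3 = c32_pos[r]; cc3 < c32_pos[r + 1]; cc3++) {
                int c = c32_crd[cc3];   /* the output slot depends on c */
                for (int ps = start3_pos[rs]; ps < start3_pos[rs + 1]; ps++) {
                    int p = start3_crd[ps];
                    assert(c < C && p < P);
                    /* Same slot for every r: accumulate, never reset. */
                    t1_vals[((size_t)q * C + c) * P + p] +=
                        start_vals[ps] * c3_vals[cc3];
                }
            }
        }
    }
}
```

With one q, two values of r, and a single output entry, the generated kernel writes the r = 0 and r = 1 products to two consecutive slots (pt1 = 0 and pt1 = 1), whereas this sketch sums both into t1_vals[0]. The only changes are hoisting the zeroing out of the loop nest and deriving the write position from the coordinates rather than from a running counter.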
