Matrix.dot is not merged #1407

Closed
Vinc0682 opened this issue May 15, 2024 · 4 comments

@Vinc0682

Hello,

I am currently developing a protocol which uses vector-matrix multiplications and I've noticed some unexpected behavior, produced by the following example code:

a = Array.create_from([sint(1), sint(2), sint(3)])
b = Array.create_from([sint(3), sint(2), sint(1)])

c = Matrix.create_from([
    [sint(1), sint(2), sint(3)],
    [sint(4), sint(5), sint(6)],
    [sint(7), sint(8), sint(9)]
])

d = Matrix.create_from([
    [sint(9), sint(8), sint(7)],
    [sint(6), sint(5), sint(4)],
    [sint(1), sint(2), sint(3)]
])

break_point()


def manual_array_dot_matrix(arr: Array, mat: Matrix) -> Array:
    assert len(arr) == mat.shape[0]

    result = sint.Array(mat.shape[1])
    for i in range(len(arr)):
        result[:] += mat[i][:] * arr[i]
    return result


def dot_array_dot_matrix(arr: Array, mat: Matrix) -> Array:
    # Arrays sadly do not have a dot function, therefore the array is converted into a 1 times n Matrix by copying data.
    tmp = sint.Matrix(rows=1, columns=len(arr))
    tmp[:] = arr[:]
    tmp = tmp.dot(mat)

    result = sint.Array(mat.shape[1])
    result[:] = tmp[:]
    return result


def hacky_array_dot_matrix(arr: Array, mat: Matrix) -> Array:
    # Arrays sadly do not have a dot function, therefore the array is converted into a 1 times n Matrix by copying register pointers.
    tmp = sint.Matrix(rows=1, columns=len(arr), address=arr.address)
    result = tmp.dot(mat)
    return sint.Array(mat.shape[1], result.address)


def custom_matmuls(arr: Array, mat: Matrix) -> Array:
    assert len(arr) == mat.shape[0]

    result = sint.new_vector(mat.shape[1])
    matmuls(result, arr[:], mat[:], 1, len(arr), mat.shape[1])
    return Array.create_from(result)


start_timer(1)

e1 = manual_array_dot_matrix(a, c)
f1 = manual_array_dot_matrix(b, d)

stop_timer(1)

e1.print_reveal_nested()
f1.print_reveal_nested()

start_timer(2)

e2 = dot_array_dot_matrix(a, c)
f2 = dot_array_dot_matrix(b, d)

stop_timer(2)

e2.print_reveal_nested()
f2.print_reveal_nested()

start_timer(3)

e3 = hacky_array_dot_matrix(a, c)
f3 = hacky_array_dot_matrix(b, d)

stop_timer(3)

e3.print_reveal_nested()
f3.print_reveal_nested()

start_timer(4)

e4 = custom_matmuls(a, c)
f4 = custom_matmuls(b, d)

stop_timer(4)

e4.print_reveal_nested()
f4.print_reveal_nested()

When executed with the replicated-ring protocol (on MP-SPDZ version 0.3.8 as well as the current master branch), this yields the following output:

Time1 = 9.0682e-05 seconds (0.000144 MB, 1 rounds)
Time2 = 0.000137906 seconds (4.8e-05 MB, 2 rounds)
Time3 = 0.000143426 seconds (4.8e-05 MB, 2 rounds)
Time4 = 4.5254e-05 seconds (4.8e-05 MB, 1 rounds)

As shown, the vector-matrix multiplications are not merged when using the Matrix.dot function: timers 2 and 3 report two rounds for two independent products, whereas timers 1 and 4 report one. This would require me to use either the first approach, which drastically increases the compile time, or the very hacky fourth method.
Neither solution seems very good. Furthermore, I can see others (and my future self) wrongly relying on the dot function being parallelized / mergeable.
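
To make the round counts above concrete, here is a toy model in plain Python. It is not MP-SPDZ code, and `Op` and `count_rounds` are hypothetical names. It assumes that independent mergeable operations of the same kind share one communication round, while each non-mergeable operation needs a round of its own.

```python
from dataclasses import dataclass


@dataclass
class Op:
    """A single operation that needs communication (toy stand-in)."""
    name: str
    mergeable: bool


def count_rounds(ops):
    """Count rounds for a list of mutually independent operations.

    Toy rule (an assumption, not the real scheduler): all mergeable
    operations with the same name share one round, and every
    non-mergeable operation takes its own round.
    """
    merged_kinds = {op.name for op in ops if op.mergeable}
    unmerged = sum(1 for op in ops if not op.mergeable)
    return len(merged_kinds) + unmerged


# Two independent products, as in the timed sections above.
two_matmuls = [Op("matmuls", True), Op("matmuls", True)]
two_matmulsm = [Op("matmulsm", False), Op("matmulsm", False)]

print(count_rounds(two_matmuls))   # mergeable instruction
print(count_rounds(two_matmulsm))  # non-mergeable instruction
```

Under this simplified rule, the two matmuls calls collapse into one round and the two matmulsm calls take two, which is the pattern seen for timers 4 and 2/3.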

I have identified two ways the dot function could be made mergeable:

  • Changing the dot function to use the MATMULS instruction, which I assume is not used for a good reason?
  • Modifying the MATMULSM instruction to be mergeable, which AFAIK would require changes to the VM and the bytecode.

Is it intended that the dot function is not merged? If not, which way to fix this issue would you recommend?

@mkskeller
Member

The reason to have matmulsm is to avoid loading all inputs to registers (the m at the end stands for memory). The only reason not to have matmulsm mergeable is the effort involved. You should be able to force dot() to use matmuls using del sint.direct_matrix_mul.
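
For readers wondering why deleting a class attribute changes which instruction dot() emits, the mechanism can be sketched in plain Python. This is a toy model, not the MP-SPDZ source: `SecretInt` and the free function `dot` are hypothetical, and the assumption is that dot() uses the direct path only while the value type provides `direct_matrix_mul`.

```python
class SecretInt:
    """Hypothetical stand-in for a value type such as sint."""

    @classmethod
    def direct_matrix_mul(cls, a, b):
        # Stands for the memory-based path (matmulsm).
        return "matmulsm"


def dot(value_type, a, b):
    """Toy dispatch: prefer the direct path if the type offers it."""
    if hasattr(value_type, "direct_matrix_mul"):
        return value_type.direct_matrix_mul(a, b)
    # Fallback stands for the register-based path (matmuls).
    return "matmuls"


before = dot(SecretInt, None, None)
del SecretInt.direct_matrix_mul  # same idea as: del sint.direct_matrix_mul
after = dot(SecretInt, None, None)

print(before, after)
```

In an actual program, the `del sint.direct_matrix_mul` statement would presumably need to run before the calls to dot() that should use matmuls.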

@Vinc0682
Author

Thanks for the swift response. This works for now.
In case I (or a student assistant) attempt to make matmulsm mergeable, what would the general steps look like?

My guess would be that it should be turned into a varargs instruction, followed by adjustments in the SubProcessor (and to the Hemi protocol). How would this be handled on the compiler side?

@mkskeller
Member

You're correct about the virtual machine side. The compiler side would be relatively easy. You would need to adapt arg_format to cycling (see matmuls for an example) and make matmulsm a sub-class of base.Mergeable.
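
As a rough illustration of what a cycling arg_format and mergeability mean together, here is a self-contained toy in plain Python. It does not use the MP-SPDZ compiler: `Mergeable` and `ToyMatmul` are hypothetical stand-ins, and the format codes and the six-argument group are assumptions chosen to resemble a matrix-multiplication instruction.

```python
import itertools


class Mergeable:
    """Toy marker class, standing in for base.Mergeable."""


class ToyMatmul(Mergeable):
    # One argument group: result, left operand, right operand,
    # then three dimensions. The codes are illustrative only.
    group = ["sw", "s", "s", "int", "int", "int"]

    def __init__(self, *args):
        if len(args) % len(self.group) != 0:
            raise ValueError("arguments must come in whole groups")
        self.args = list(args)

    def formats(self):
        """Expected format per argument, repeating the group cyclically."""
        cycle = itertools.cycle(self.group)
        return list(itertools.islice(cycle, len(self.args)))

    def merge(self, other):
        """Absorb another instruction by appending its argument groups."""
        self.args += other.args


first = ToyMatmul("r0", "a", "c", 1, 3, 3)
second = ToyMatmul("r1", "b", "d", 1, 3, 3)
first.merge(second)

print(len(first.args))
print(first.formats())
```

Because the format repeats, one merged instruction can carry any number of argument groups, which is what would let the virtual machine handle several independent multiplications in a single round.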

@Vinc0682
Author

Thank you very much. I may attempt making matmulsm mergeable once I have the time to do so.
