Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster const-time modinv divsteps (rebase of #1031) #1197

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
148 changes: 85 additions & 63 deletions doc/safegcd_implementation.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,14 @@ do one division by *2<sup>N</sup>* as a final step:
```python
def divsteps_n_matrix(delta, f, g):
"""Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
u, v, q, r = 1, 0, 0, 1 # start with identity matrix
u, v, q, r = 1<<N, 0, 0, 1<<N # start with identity matrix (scaled by 2^N)
for _ in range(N):
if delta > 0 and g & 1:
delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
delta, f, g, u, v, q, r = 1 - delta, g, (g-f)//2, q, r, (q-u)//2, (r-v)//2
elif g & 1:
delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
delta, f, g, u, v, q, r = 1 + delta, f, (g+f)//2, u, v, (q+u)//2, (r+v)//2
else:
delta, f, g, u, v, q, r = 1 + delta, f, (g ) // 2, 2*u, 2*v, q , r
delta, f, g, u, v, q, r = 1 + delta, f, (g )//2, u, v, (q )//2, (r )//2
return delta, (u, v, q, r)
```

Expand Down Expand Up @@ -414,9 +414,9 @@ operations (and hope the C compiler isn't smart enough to turn them back into br
divstep can be written instead as (compare to the inner loop of `gcd` in section 1).

```python
x = -f if delta > 0 else f # set x equal to (input) -f or f
x = f if delta > 0 else -f # set x equal to (input) f or -f
if g & 1:
g += x # set g to (input) g-f or g+f
g -= x # set g to (input) g-f or g+f
if delta > 0:
delta = -delta
f += g # set f to (input) g (note that g was set to g-f before)
Expand All @@ -433,19 +433,21 @@ that *-v == (v ^ -1) - (-1)*. Thus, if we have a variable *c* that takes on valu
Using this we can write:

```python
x = -f if delta > 0 else f
x = f if delta > 0 else -f
```

in constant-time form as:

```python
c1 = (-delta) >> 63
# Compute c1=0 if delta>0 and c1=-1 if delta<=0.
c1 = (delta - 1) >> 63
# Conditionally negate f based on c1:
x = (f ^ c1) - c1
```

To use that trick, we need a helper mask variable *c1* that resolves the condition *&delta;>0* to *-1*
(if true) or *0* (if false). We compute *c1* using right shifting, which is equivalent to dividing by
To use that trick, we need a helper mask variable *c1* that resolves the condition *&delta;&leq;0* to *-1*
(if true) or *0* (if false). We compute *c1* by first subtracting *1*, which results in a negative value
if and only if *&delta;&leq;0*. That is then right shifted, which is equivalent to dividing by
the specified power of *2* and rounding down (in Python, and also in C under the assumption of a typical two's complement system; see
`assumptions.h` for tests that this is the case). Right shifting by *63* thus maps all
numbers in range *[-2<sup>63</sup>,0)* to *-1*, and numbers in range *[0,2<sup>63</sup>)* to *0*.
Expand All @@ -454,7 +456,7 @@ Using the facts that *x&0=0* and *x&(-1)=x* (on two's complement systems again),

```python
if g & 1:
g += x
g -= x
```

as:
Expand All @@ -463,7 +465,7 @@ as:
# Compute c2=0 if g is even and c2=-1 if g is odd.
c2 = -(g & 1)
# This masks out x if g is even, and leaves x be if g is odd.
g += x & c2
g -= x & c2
```

Using the conditional negation trick again we can write:
Expand All @@ -478,7 +480,7 @@ as:

```python
# Compute c3=-1 if g is odd and delta>0, and 0 otherwise.
c3 = c1 & c2
c3 = ~c1 & c2
# Conditionally negate delta based on c3:
delta = (delta ^ c3) - c3
```
Expand All @@ -497,45 +499,61 @@ becomes:
f += g & c3
```

It turns out that this can be implemented more efficiently by applying the substitution
*&eta;=-&delta;*. In this representation, negating *&delta;* corresponds to negating *&eta;*, and incrementing
*&delta;* corresponds to decrementing *&eta;*. This allows us to remove the negation in the *c1*
computation:
Putting everything together, extending all operations on f,g (with helper x) to also be applied
to u,q (with helper y) and v,r (with helper z), gives:

```python
# Compute a mask c1 for eta < 0, and compute the conditional negation x of f:
c1 = eta >> 63
x = (f ^ c1) - c1
# Compute a mask c2 for odd g, and conditionally add x to g:
c2 = -(g & 1)
g += x & c2
# Compute a mask c for (eta < 0) and odd (input) g, and use it to conditionally negate eta,
# and add g to f:
c3 = c1 & c2
eta = (eta ^ c3) - c3
f += g & c3
# Incrementing delta corresponds to decrementing eta.
eta -= 1
g >>= 1
def divsteps_n_matrix(delta, f, g):
"""Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
u, v, q, r = 1<<N, 0, 0, 1<<N # start with identity matrix (scaled by 2^N).
for i in range(N):
c1 = (delta - 1) >> 63
# Compute x, y, z as conditionally-negated versions of f, u, v.
x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
c2 = -(g & 1)
# Conditionally subtract x, y, z from g, q, r.
g, q, r = g - (x & c2), q - (y & c2), r - (z & c2)
c3 = ~c1 & c2
# Conditionally negate delta, and then increment it by 1.
delta = (delta ^ c3) - c3 + 1
# Conditionally add g, q, r to f, u, v.
f, u, v = f + (g & c3), u + (q & c3), v + (r & c3)
# Shift down g, q, r.
g, q, r = g >> 1, u >> 1, v >> 1
return delta, (u, v, q, r)
```

A variant of divsteps with better worst-case performance can be used instead: starting *&delta;* at
An interesting optimization is possible here. If we were to drop the *-c1* in the computation
of *x*, *y*, and *z*, we are making them at worst *1* less than the correct value. That
translates to *g*, *q*, and *r* further being at worst *1* more than the correct value.
Now observe that at the start of every iteration of the loop, *u*, *v*, *q*, and *r* are
all multiples of *2<sup>N-i</sub>*, with *i* the iteration number, and thus all even.
In other words, this potential off by one in *g*, *q*, and *r* only affects their bottommost
bit, which is shifted away at the end of the loop. Thus we can instead write:

```python
# Compute x, y, z as conditionally complemented versions of f, u, v.
x, y, z = f ^ c1, u ^ c1, v ^ c1
```

Finally, a variant of divsteps with better worst-case performance can be used instead: starting *&delta;* at
*1/2* instead of *1*. This reduces the worst case number of iterations to *590* for *256*-bit inputs
(which can be shown using convex hull analysis). In this case, the substitution *&zeta;=-(&delta;+1/2)*
is used instead to keep the variable integral. Incrementing *&delta;* by *1* still translates to
decrementing *&zeta;* by *1*, but negating *&delta;* now corresponds to going from *&zeta;* to *-(&zeta;+1)*, or
*~&zeta;*. Doing that conditionally based on *c3* is simply:
(which can be shown using [convex hull analysis](https://github.com/sipa/safegcd-bounds)).
In this case, the substitution *&theta;=&delta;-1/2* is used to keep the variable integral.
*&delta;&leq;0* then translates to *&theta;&leq;-1/2*, or because *&theta;* is integral, *&theta;<0*.
Thus instead of `c1 = (delta - 1) >> 63` we get `c1 = theta >> 63`.
Negating *&delta;* now corresponds to going from *&theta;* to
*-&theta;-1*. Doing that conditionally based on *c3* (and then incrementing by one) gives us:

```python
...
c3 = c1 & c2
zeta ^= c3
theta = (theta ^ c3) + 1
...
```

By replacing the loop in `divsteps_n_matrix` with a variant of the divstep code above (extended to
also apply all *f* operations to *u*, *v* and all *g* operations to *q*, *r*), a constant-time version of
`divsteps_n_matrix` is obtained. The full code will be in section 7.
`divsteps_n_matrix` is obtained. The resulting code will be in section 7.

These bit fiddling tricks can also be used to make the conditional negations and additions in
`update_de` and `normalize` constant-time.
Expand All @@ -550,7 +568,7 @@ faster non-constant time `divsteps_n_matrix` function.

To do so, first consider yet another way of writing the inner loop of divstep operations in
`gcd` from section 1. This decomposition is also explained in the paper in section 8.2. We use
the original version with initial *&delta;=1* and *&eta;=-&delta;* here.
the original version with initial *&delta;=1*, but make the substitution *&eta;=-&delta;*.

```python
for _ in range(N):
Expand Down Expand Up @@ -651,37 +669,41 @@ Here we need the negated modular inverse, which is a simple transformation of th
have this 6-bit function (based on the 3-bit function above):
- *f(f<sup>2</sup> - 2)*

This loop, again extended to also handle *u*, *v*, *q*, and *r* alongside *f* and *g*, placed in
`divsteps_n_matrix`, gives a significantly faster, but non-constant time version.
This loop, extended to also handle *u*, *v*, *q*, and *r* alongside *f* and *g*, placed in
`divsteps_n_matrix`, gives a significantly faster, but non-constant time version. In order to
avoid intermediary values that need more than N+1 bits, it is possible to instead start
*u* and *v* at *1* instead of at *2<sup>N</sup>*, and then shift up *u* and *v* whenever
*g* is shifted down (instead of shifting down *q* and *r*). This is effectively making the
algorithm operate on *i*-bits downshifted versions of all these variables. The resulting
code is shown in the next section.


## 7. Final Python version

All together we need the following functions:

- A way to compute the transition matrix in constant time, using the `divsteps_n_matrix` function
from section 2, but with its loop replaced by a variant of the constant-time divstep from
section 5, extended to handle *u*, *v*, *q*, *r*:
from section 5, modified to operate on *&theta;* instead of *&delta;*:

```python
def divsteps_n_matrix(zeta, f, g):
"""Compute zeta and transition matrix t after N divsteps (multiplied by 2^N)."""
u, v, q, r = 1, 0, 0, 1 # start with identity matrix
def divsteps_n_matrix(theta, f, g):
"""Compute theta and transition matrix t after N divsteps (multiplied by 2^N)."""
u, v, q, r = 1<<N, 0, 0, 1<<N # start with identity matrix (scaled by 2^N).
for _ in range(N):
c1 = zeta >> 63
# Compute x, y, z as conditionally-negated versions of f, u, v.
x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
c1 = theta >> 63
# Compute x, y, z as conditionally complemented versions of f, u, v.
x, y, z = f ^ c1, u ^ c1, v ^ c1
c2 = -(g & 1)
# Conditionally add x, y, z to g, q, r.
g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
c1 &= c2 # reusing c1 here for the earlier c3 variable
zeta = (zeta ^ c1) - 1 # inlining the unconditional zeta decrement here
# Conditionally subtract x, y, z from g, q, r.
g, q, r = g - (x & c2), q - (y & c2), r - (z & c2)
c3 = ~c1 & c2
# Conditionally complement theta, and then increment it by 1.
theta = (theta ^ c3) + 1
# Conditionally add g, q, r to f, u, v.
f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
# When shifting g down, don't shift q, r, as we construct a transition matrix multiplied
# by 2^N. Instead, shift f's coefficients u and v up.
g, u, v = g >> 1, u << 1, v << 1
return zeta, (u, v, q, r)
f, u, v = f + (g & c3), u + (q & c3), v + (r & c3)
# Shift down f, q, r.
g, q, r = g >> 1, u >> 1, v >> 1
return theta, (u, v, q, r)
```

- The functions to update *f* and *g*, and *d* and *e*, from section 2 and section 4, with the constant-time
Expand Down Expand Up @@ -723,15 +745,15 @@ def normalize(sign, v, M):
return v
```

- And finally the `modinv` function too, adapted to use *&zeta;* instead of *&delta;*, and using the fixed
- And finally the `modinv` function too, adapted to use *&theta;* instead of *&delta;*, and using the fixed
iteration count from section 5:

```python
def modinv(M, Mi, x):
"""Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
zeta, f, g, d, e = -1, M, x, 0, 1
theta, f, g, d, e = 0, M, x, 0, 1
for _ in range((590 + N - 1) // N):
zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
theta, t = divsteps_n_matrix(theta, f % 2**N, g % 2**N)
f, g = update_fg(f, g, t)
d, e = update_de(d, e, t, M, Mi)
return normalize(f, d, M)
Expand All @@ -745,7 +767,7 @@ def modinv(M, Mi, x):
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1] # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n
def divsteps_n_matrix_var(eta, f, g):
"""Compute eta and transition matrix t after N divsteps (multiplied by 2^N)."""
u, v, q, r = 1, 0, 0, 1
u, v, q, r = 1, 0, 0, 1 # Start with identity matrix (not scaled; shift during run instead).
i = N
while True:
zeros = min(i, count_trailing_zeros(g))
Expand Down