Implementing fast modular exponentiation - a guide #126

Open
mratsim opened this issue May 19, 2023 · 2 comments

Comments

@mratsim
Contributor

mratsim commented May 19, 2023

The goal of this issue is to provide a guideline on how to fix status-im/nimbus-eth1#1584.

Also pinging @treeform, @guzba on how to implement fast RSA (modexp is the bottleneck)
and @dlesnoff for nim/bigint

Here is a write-up on how to implement fast modular exponentiation.

Recommended textbook for implementer:

We assume 64-bit words.

Vocabulary

In textbooks, you'll encounter the term "reduction". Read it as "modular reduction": it just means taking the remainder after Euclidean division.

The assembly myth

First of all, let's dispel a myth: that assembly is the key to speed.
Assembly does help, but focusing on it misses the big picture.

Some benchmarks on Constantine modular exponentiation on BN254 prime field by a 254-bit integer:

  • GCC 13.1.1 + assembly
    • constant-time 131147.541 ops/s, 7625 ns/op, 25180 CPU cycles
    • variable-time 182882.224 ops/s, 5468 ns/op, 18056 CPU cycles
  • GCC 13.1.1 - no assembly
    • constant-time 87351.502 ops/s, 11448 ns/op, 37805 CPU cycles
    • variable-time 103842.160 ops/s, 9630 ns/op, 31801 CPU cycles
  • Clang 15.0.7 + assembly
    • constant-time 131268.049 ops/s, 7618 ns/op, 25157 CPU cycles
    • variable-time 184638.109 ops/s, 5416 ns/op, 17887 CPU cycles
  • Clang 15.0.7 - no assembly
    • constant-time 116726.976 ops/s, 8567 ns/op, 28291 CPU cycles
    • variable-time 143513.203 ops/s, 6968 ns/op, 23012 CPU cycles

So with Clang, assembly is "only" about 12% (constant-time) to 29% (variable-time) faster than no assembly.
And don't use GCC for bigints if you don't use assembly, it's just bad: there the gap grows to about 50% to 76%.

The big picture

Modular exponentiation is implemented through an algorithm called square-and-multiply (the exponentiation counterpart of double-and-add), which does as many modular squarings as there are bits in the exponent and as many modular multiplications as there are set bits in the exponent.

For random numbers, about 50% of the bits are set. Assuming a 256-bit exponent, that's 256 squarings + 128 multiplications = 384 modular multiplications/squarings.

Each modular multiplication is naively a multiplication 256-bit x 256-bit -> 512-bit and then modulo a 256-bit number.

The bottleneck

Let's take https://www.agner.org/optimize/instruction_tables.pdf and have a look at the speed of
the DIV instruction, which is necessary to compute modulo.

x86 started to be extremely well optimized for bigints with Broadwell, which introduced ADCX and ADOX (MULX was introduced in Haswell, Broadwell's predecessor).

DIV on 64-bit input takes 36 cycles, and has a latency of up to 95 cycles (i.e. anything that depends on that result may wait up to 95 cycles before proceeding).


In comparison add and shifts are just 1 cycle. So anything that uses division starts with a heavy disadvantage.

Note: even with that disadvantage, DIV is still faster than doing bit-by-bit division like here (this algorithm is chosen when the operands differ by at most 8 bits in length):

func divmodBS(x, y: UintImpl, q, r: var UintImpl) =
  ## Division for multi-precision unsigned uint
  ## Implementation through binary shift division
  doAssert y.isZero.not() # This should be checked on release mode in the divmod caller proc

  type SubTy = type x.lo

  var
    shift = y.leadingZeros - x.leadingZeros
    d = y shl shift

  r = x

  while shift >= 0:
    q += q
    if r >= d:
      r -= d
      q.lo = q.lo or one(SubTy)

    d = d shr 1
    dec(shift)

const BinaryShiftThreshold = 8  # If the difference in bit-length is below 8
                                # binary shift is probably faster

Back-of-the-napkin perf:

A 256-bit modular reduction will need 4 64-bit DIV (and other things). At up to 95 cycles of latency each, that is already a cost of about 400 cycles.
We need that for 384 operations, so 384x400 = 153600 cycles. That's roughly 4x to 5x more costly than my slow benchmark of GCC without assembly at 31801 to 37805 cycles (well, it's on 254-bit instead of 256-bit).

And there is a lot more work beside the divisions, see

func div2n1n[T: SomeUnsignedInt](q, r: var T, n_hi, n_lo, d: T) =
  # doAssert leadingZeros(d) == 0, "Divisor was not normalized"

  const
    size = bitsof(q)
    halfSize = size div 2
    halfMask = (1.T shl halfSize) - 1.T

  template halfQR(n_hi, n_lo, d, d_hi, d_lo: T): tuple[q,r: T] =
    var (q, r) = divmod(n_hi, d_hi)
    let m = q * d_lo
    r = (r shl halfSize) or n_lo

    # Fix the remainder, we're at most 2 iterations off
    if r < m:
      dec q
      r += d
      if r >= d and r < m:
        dec q
        r += d
    r -= m
    (q, r)

  let
    d_hi = d shr halfSize
    d_lo = d and halfMask
    n_lohi = n_lo shr halfSize
    n_lolo = n_lo and halfMask

  # First half of the quotient
  let (q1, r1) = halfQR(n_hi, n_lohi, d, d_hi, d_lo)

  # Second half
  let (q2, r2) = halfQR(r1, n_lolo, d, d_hi, d_lo)

  q = (q1 shl halfSize) or q2
  r = r2

i.e. all the example implementations on Wikipedia are really slow: https://en.wikipedia.org/wiki/Modular_arithmetic#Example_implementations

How to avoid division

There are 2 main techniques to avoid costly divisions:

Barrett reduction

https://en.wikipedia.org/wiki/Barrett_reduction

Instead of dividing a*b by m, you precompute µ = (2⁶⁴)ᵏ/m once, estimate the quotient as q = a*b*µ shifted right by k words (i.e. divided by (2⁶⁴)ᵏ), and get the remainder as a*b - q*m, followed by at most a couple of subtractions of m.
This is called Barrett reduction and is interesting when (2⁶⁴)ᵏ/m can be reused many times.
k is chosen so that the division by m has an inconsequential rounding error.
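
To make the Barrett steps concrete, here is a minimal single-word sketch in Nim. It is a toy under stated assumptions, not a multi-limb implementation: the shift is `k = 32` bits, inputs are below 2³² (e.g. products of values reduced modulo some m < 2¹⁶), and the names `barrettPrecompute` and `barrettReduce` are made up for illustration.

```nim
# Toy Barrett reduction on one machine word (assumption: x < 2^32),
# so that every intermediate fits in uint64.
const k = 32   # shift amount in bits

func barrettPrecompute(m: uint64): uint64 =
  ## mu = floor(2^k / m): the only true division, done once per modulus.
  (1'u64 shl k) div m

func barrettReduce(x, m, mu: uint64): uint64 =
  ## x mod m using one multiplication, one shift and a few subtractions.
  let q = (x * mu) shr k     # quotient estimate, never above floor(x / m)
  result = x - q * m         # candidate remainder, possibly slightly too large
  while result >= m:         # small correction instead of a division
    result -= m

when isMainModule:
  let m = 40961'u64
  let mu = barrettPrecompute(m)
  for x in [0'u64, 1, 40960, 40961, 123456789, 4294967295'u64]:
    doAssert barrettReduce(x, m, mu) == x mod m
  echo "Barrett toy OK"
```

The division happens once in `barrettPrecompute`; every later reduction by the same modulus only pays for multiplications, a shift and subtractions.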

Montgomery reduction

Montgomery reduction uses a similar approach to Barrett reduction but with lower complexity,
at the price of having to "transport"/convert the number being reduced to the "Montgomery domain".

In practice we do all computation on a' = aR (mod m) with R = (2⁶⁴)^(numWords) (mod m), numWords = 4 for 256-bit numbers.

Once we have numbers in the Montgomery domain, there is an operation called Montgomery modular multiplication (montMul) that does: montMul(a', b') = montMul(aR, bR) = abR (mod m)

montMul is the fastest modular multiplication algorithm that works on almost all moduli.

Why almost?

Well, it only works on odd moduli, but all primes besides 2 are odd.

Anyway, if we have an odd modulus, we can compute the Montgomery modular exponentiation instead of the plain modular exponentiation.

Converting to-from the Montgomery domain only requires montMul by R² or by 1.
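
As an illustration of montMul and of the two conversions, here is a single-word sketch. The sizes are assumptions chosen so that everything fits in uint64 (R = 2³² and an odd modulus below 2³¹); the helper names `negInvModR`, `toMont` and `fromMont` are hypothetical, and the constant used is the negated inverse -1/m mod R, which is the common convention for this formulation.

```nim
# Single-word toy: R = 2^32 and an odd modulus m < 2^31 (assumption, so
# that t + u*m below never overflows uint64).
const
  rBits = 32
  rMask = (1'u64 shl rBits) - 1

func negInvModR(m: uint64): uint64 =
  ## -1/m mod 2^32 by Newton iteration; each step doubles the correct bits.
  var inv = m                        # correct to 3 bits for odd m
  for step in 0 ..< 4:               # 3 -> 6 -> 12 -> 24 -> 48 bits
    inv = inv * (2'u64 - m * inv)    # wraps mod 2^64, which is harmless here
  (0'u64 - inv) and rMask

func montMul(a, b, m, mNegInv: uint64): uint64 =
  ## Returns a*b/R mod m for a, b < m.
  let t = a * b
  let u = ((t and rMask) * mNegInv) and rMask  # makes t + u*m divisible by R
  result = (t + u * m) shr rBits               # exact division by R, no DIV
  if result >= m:
    result -= m

func toMont(a, m, mNegInv, r2: uint64): uint64 =
  montMul(a, r2, m, mNegInv)         # a * R^2 / R = aR (mod m)

func fromMont(aMont, m, mNegInv: uint64): uint64 =
  montMul(aMont, 1, m, mNegInv)      # aR * 1 / R = a (mod m)

when isMainModule:
  let m = 2147483629'u64             # odd, below 2^31
  let mNegInv = negInvModR(m)
  let rModM = (1'u64 shl rBits) mod m
  let r2 = (rModM * rModM) mod m     # R^2 mod m, computed once per modulus
  let (a, b) = (123456789'u64, 987654321'u64)
  let abMont = montMul(toMont(a, m, mNegInv, r2), toMont(b, m, mNegInv, r2), m, mNegInv)
  doAssert fromMont(abMont, m, mNegInv) == (a * b) mod m
  echo "Montgomery toy OK"
```

The only `mod` operations left are in the one-time setup of R²; the multiplication itself uses masks and shifts.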

Reconciling Montgomery and even modulus

One issue is that in the Ethereum Virtual Machine, modexp can receive any modulus, not just odd.

Thankfully, we can invoke the Chinese Remainder Theorem (CRT) (https://en.wikipedia.org/wiki/Chinese_remainder_theorem), which states that if your modulus is m = a * b with a and b coprime, you can compute mod a and mod b separately, and it gives you a way to recombine the two results mod m.

So if you have an even modulus, you split it into a = 2ᵏ, a power of 2, and b, an odd number.
Odd numbers are coprime with powers of 2, so you can always apply the CRT.

  • Doing modulo a power of 2 is very easy, x mod 2ᵏ == x and (2ᵏ - 1).
  • And for the odd number you have Montgomery.
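
The split-and-recombine idea can be sketched as follows. This is a toy under stated assumptions: the modulus is nonzero and below 2³² so products fit in uint64, a plain `mod`-based `powModNaive` stands in for the Montgomery path on the odd part, and all function names are hypothetical.

```nim
import std/bitops

func powModNaive(a, e, m: uint64): uint64 =
  ## Square-and-multiply with `mod`, standing in for the Montgomery path.
  if m == 1: return 0
  result = 1
  var base = a mod m
  var ex = e
  while ex > 0:
    if (ex and 1) == 1: result = (result * base) mod m
    base = (base * base) mod m
    ex = ex shr 1

func powMod2k(a, e: uint64, k: int): uint64 =
  ## a^e mod 2^k: let uint64 wrap around, then keep the low k bits.
  result = 1
  var base = a
  var ex = e
  while ex > 0:
    if (ex and 1) == 1: result = result * base
    base = base * base
    ex = ex shr 1
  result = result and ((1'u64 shl k) - 1)

func invMod2k(q: uint64, k: int): uint64 =
  ## 1/q mod 2^k for odd q, by Newton iteration.
  var inv = q
  for step in 0 ..< 5:
    inv = inv * (2'u64 - q * inv)
  inv and ((1'u64 shl k) - 1)

func powModAny(a, e, m: uint64): uint64 =
  ## a^e mod m for any nonzero m < 2^32, odd or even.
  let k = countTrailingZeroBits(m)     # m = 2^k * q with q odd
  let q = m shr k
  let x1 = powModNaive(a, e, q)        # residue modulo the odd part
  let x2 = powMod2k(a, e, k)           # residue modulo 2^k
  let t = ((x2 - x1) * invMod2k(q, k)) and ((1'u64 shl k) - 1)
  x1 + q * t                           # CRT recombination, already below m

when isMainModule:
  for m in [1000'u64, 4096, 999, 6, 4294967294'u64]:
    for a in [0'u64, 2, 7, 123456]:
      for e in [0'u64, 1, 5, 1000003]:
        doAssert powModAny(a, e, m) == powModNaive(a, e, m)
  echo "CRT toy OK"
```

The recombination needs the inverse of the odd part modulo 2ᵏ, which is again cheap: it is a Newton iteration on machine words, not a division.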

Engineering & Implementation

Now that we have the theory, let's look at engineering problems and reference code.

Fast Montgomery multiplication

There are many ways to multiply bigints; they were categorized in Acar's thesis: https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf

  • Separated Operand Scanning (SOS): multiprecision multiplication and Montgomery reduction are done separately
  • Coarsely Integrated Operand Scanning (CIOS): This interleaves schoolbook multiplication and Montgomery reduction
  • Finely Integrated Operand Scanning (FIOS): similar to previous
  • Finely Integrated Product Scanning (FIPS): When doing schoolbook multiplication you can compute rows-by-rows or columns-by-columns, this does columns-by-columns.
  • Coarsely integrated Hybrid Scanning (CIHS)

On modern architectures, with MULX/ADCX/ADOX, CIOS is the fastest. Without MULX/ADCX/ADOX (and so potentially on everything besides x86), FIPS is the fastest. Why? AFAIK data movement is better and there are fewer carries to save/restore.
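
For orientation, here is a hedged sketch of the CIOS structure following the description in Acar's thesis. It is not Constantine's code. It assumes 32-bit limbs stored least-significant first so that every partial sum fits in uint64 (production code would use 64-bit limbs), and the names `Limbs`, `negInvModWord`, `geq`, `toLimbs` and `montMulCIOS` are illustrative.

```nim
type Limbs = seq[uint32]   # little-endian limbs (assumption for this sketch)

func negInvModWord(m0: uint32): uint32 =
  ## -1/m0 mod 2^32 for odd m0 (Newton iteration, wrapping arithmetic).
  var inv = m0
  for step in 0 ..< 4:
    inv = inv * (2'u32 - m0 * inv)
  0'u32 - inv

func geq(t: openArray[uint32], m: Limbs): bool =
  ## t >= m, where t carries one extra high limb at index m.len.
  if t[m.len] != 0: return true
  for i in countdown(m.len - 1, 0):
    if t[i] != m[i]: return t[i] > m[i]
  true

func toLimbs(x: uint64): Limbs =
  @[uint32(x and 0xFFFFFFFF'u64), uint32(x shr 32)]

func montMulCIOS(a, b, m: Limbs, m0ninv: uint32): Limbs =
  ## a*b/R mod m with R = (2^32)^m.len, for a, b < m and m odd.
  let s = m.len
  var t = newSeq[uint32](s + 2)
  for i in 0 ..< s:
    # Multiplication step: t += a * b[i]
    var carry = 0'u64
    for j in 0 ..< s:
      let acc = uint64(t[j]) + uint64(a[j]) * uint64(b[i]) + carry
      t[j] = uint32(acc and 0xFFFFFFFF'u64)
      carry = acc shr 32
    let top = uint64(t[s]) + carry
    t[s] = uint32(top and 0xFFFFFFFF'u64)
    t[s + 1] = uint32(top shr 32)
    # Reduction step: add the multiple of m that zeroes t[0], shift down one limb
    let q = uint64(t[0] * m0ninv)
    carry = (uint64(t[0]) + q * uint64(m[0])) shr 32
    for j in 1 ..< s:
      let acc = uint64(t[j]) + q * uint64(m[j]) + carry
      t[j - 1] = uint32(acc and 0xFFFFFFFF'u64)
      carry = acc shr 32
    let top2 = uint64(t[s]) + carry
    t[s - 1] = uint32(top2 and 0xFFFFFFFF'u64)
    t[s] = t[s + 1] + uint32(top2 shr 32)
  # Final conditional subtraction brings the result below m
  if geq(t, m):
    var borrow = 0'u64
    for j in 0 ..< s:
      let d = uint64(t[j]) - uint64(m[j]) - borrow   # wraps when negative
      t[j] = uint32(d and 0xFFFFFFFF'u64)
      borrow = d shr 63
  t[0 ..< s]

when isMainModule:
  let m64 = 0xFFFFFFFF00000001'u64             # odd 64-bit modulus, 2 limbs
  let m = toLimbs(m64)
  let m0ninv = negInvModWord(m[0])
  let rModM = toLimbs((0'u64 - m64) mod m64)   # R mod m with R = 2^64
  let a = toLimbs(0x0123456789ABCDEF'u64)
  doAssert montMulCIOS(a, rModM, m, m0ninv) == a   # a * R / R = a
  echo "CIOS sketch OK"
```

The self-check relies on the identity montMul(a, R mod m) = a, which avoids needing a second bigint implementation as a reference.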

The no-assembly algorithm for both is available here:

Fast large multiplication

One issue with multiplication is that it's O(n²) regarding the number of words since we need to multiply each word in each multiplicand with each word in the other.

The Karatsuba algorithm has a complexity of about O(n¹˙⁵⁸), with a constant-factor overhead that becomes negligible at around 8~12 words (so 512~768 bits, to be measured), https://en.wikipedia.org/wiki/Karatsuba_algorithm
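
The core Karatsuba trick fits in a few lines. The sketch below applies one level of it to 16-bit halves of 32-bit numbers (toy sizes chosen for illustration; a real implementation recurses on limb arrays), and the name `karatsuba32` is made up.

```nim
# One level of Karatsuba: three half-size multiplications instead of four.
func karatsuba32(a, b: uint32): uint64 =
  let
    a0 = uint64(a and 0xFFFF)               # low halves
    b0 = uint64(b and 0xFFFF)
    a1 = uint64(a shr 16)                   # high halves
    b1 = uint64(b shr 16)
    z0 = a0 * b0                            # low x low
    z2 = a1 * b1                            # high x high
    z1 = (a0 + a1) * (b0 + b1) - z0 - z2    # both cross terms from one product
  (z2 shl 32) + (z1 shl 16) + z0

when isMainModule:
  for (x, y) in [(0'u32, 0'u32), (65535'u32, 65536'u32), (123456789'u32, 4294967295'u32)]:
    doAssert karatsuba32(x, y) == uint64(x) * uint64(y)
  echo "Karatsuba toy OK"
```

The saved multiplication is paid for with extra additions and subtractions, which is why the method only wins above some operand size.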

Fast exponentiation

Now we can look into exponentiation. The basic algorithm is square-and-multiply https://en.wikipedia.org/wiki/Exponentiation_by_squaring

You scan the bits of the exponent (left-to-right, i.e. MSB-to-LSB, or right-to-left, i.e. LSB-to-MSB; both variants are possible). You always square, and then you multiply by something (depending on your scanning direction) if the bit is set.
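
Both scanning directions can be sketched as below. Assumptions: the modulus is below 2³² so products fit in uint64, and the `mod` operator stands in for montMul; the function names are illustrative.

```nim
func powModLtr(a, e, m: uint64): uint64 =
  ## Left-to-right (MSB first): square for every bit, multiply when it is set.
  result = 1'u64 mod m
  for i in countdown(63, 0):
    result = (result * result) mod m
    if ((e shr i) and 1) == 1:
      result = (result * (a mod m)) mod m

func powModRtl(a, e, m: uint64): uint64 =
  ## Right-to-left (LSB first): keep squaring the base, multiply it in when set.
  result = 1'u64 mod m
  var base = a mod m
  var ex = e
  while ex > 0:
    if (ex and 1) == 1:
      result = (result * base) mod m
    base = (base * base) mod m
    ex = ex shr 1

when isMainModule:
  doAssert powModLtr(4, 13, 497) == 445
  doAssert powModRtl(4, 13, 497) == 445
  echo "square-and-multiply OK"
```

In the left-to-right variant the multiplier is always the original base, which is what makes precomputed windows possible.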

However, you can precompute the powers for bit patterns, for example
1101, 1001, 0111, and instead of doing 2 or 3 multiplications per 4 bits, ensure that you only do 1 every 4 squarings.

That's called the window method.
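
A variable-time fixed-window version with 4-bit windows might look like this. It is a toy sketch, not Constantine's implementation: it assumes a modulus below 2³² and a 64-bit exponent, uses `mod` in place of montMul, and `powModWindow4` is a hypothetical name.

```nim
func powModWindow4(a, e, m: uint64): uint64 =
  # Precompute a^0 .. a^15 once: 15 multiplications up front.
  var table: array[16, uint64]
  table[0] = 1'u64 mod m
  for i in 1 ..< 16:
    table[i] = (table[i - 1] * (a mod m)) mod m
  result = table[0]
  for w in countdown(15, 0):              # 16 windows of 4 bits, MSB first
    for s in 0 ..< 4:
      result = (result * result) mod m    # 4 squarings per window
    let bits = (e shr (4 * w)) and 0xF
    if bits != 0:                         # variable-time shortcut on zero windows
      result = (result * table[int(bits)]) mod m

when isMainModule:
  doAssert powModWindow4(4, 13, 497) == 445
  echo "fixed window OK"
```

The table costs memory and setup time, so the window size is a trade-off that depends on the exponent length.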

An example fixed-window implementation is available in Constantine: https://github.com/mratsim/constantine/blob/1c5341f/constantine/math/arithmetic/limbs_montgomery.nim#L614-L836. As there is no constant-time requirement in Stint, it can be simplified; I needed to ensure I could use the value of secret bits for windows (for example, RSA uses modular exponentiation) without revealing what the secret bits are.

Sliding window

An extra optimization is using windows of variable size, called sliding window. I don't have an implementation of that (it cannot be made constant-time) but Wikipedia has pseudocode: https://en.wikipedia.org/wiki/Exponentiation_by_squaring#Sliding-window_method
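
As a rough illustration, here is one possible left-to-right sliding-window variant. It is a sketch written for this guide, not a transcription of the Wikipedia pseudocode. It assumes a modulus below 2³², a 64-bit exponent and windows of at most 4 bits; `powModSliding` is a hypothetical name.

```nim
func powModSliding(a, e, m: uint64): uint64 =
  const w = 4                               # maximum window width in bits
  let a1 = a mod m
  let a2 = (a1 * a1) mod m
  var odd: array[8, uint64]                 # odd[j] = a^(2j+1): only odd powers
  odd[0] = a1
  for j in 1 ..< 8:
    odd[j] = (odd[j - 1] * a2) mod m
  result = 1'u64 mod m
  var i = 63
  while i >= 0:
    if ((e shr i) and 1) == 0:
      result = (result * result) mod m      # zero bit: square and slide by one
      dec i
    else:
      var lo = max(i - w + 1, 0)
      while ((e shr lo) and 1) == 0: inc lo # window must end on a set bit
      for s in lo .. i:
        result = (result * result) mod m    # one squaring per bit in the window
      let bits = (e shr lo) and ((1'u64 shl (i - lo + 1)) - 1)
      result = (result * odd[int(bits shr 1)]) mod m
      i = lo - 1

when isMainModule:
  doAssert powModSliding(4, 13, 497) == 445
  echo "sliding window OK"
```

Because every window starts and ends on a set bit, only odd powers need to be precomputed, which halves the table compared to the fixed-window method.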

NAF/signed recoding

In your research you might come across NAF or signed recoding. This only applies to elliptic curves (because inversion 1/a (mod m) is not cheap in modular arithmetic, whereas negation -P is cheap on an elliptic curve).

Montgomery domain, conversion and constants

Computing R and R² can be done like this: https://github.com/mratsim/constantine/blob/1c5341f/constantine/math/config/precompute.nim#L307-L350
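
One simple way to get these constants without any division is repeated modular doubling. The sketch below is a single-word toy and not necessarily what the linked code does: it assumes a modulus below 2⁶³ so that doubling cannot overflow, and the names `doubleMod` and `montyConstants` are made up.

```nim
func doubleMod(x, m: uint64): uint64 =
  ## 2x mod m for x < m, using one addition and one conditional subtraction.
  result = x + x
  if result >= m: result -= m

func montyConstants(m: uint64, rBits: int): tuple[r, r2: uint64] =
  ## R = 2^rBits mod m and R^2 = 2^(2*rBits) mod m.
  var x = 1'u64 mod m
  for i in 0 ..< rBits: x = doubleMod(x, m)
  result.r = x
  for i in 0 ..< rBits: x = doubleMod(x, m)
  result.r2 = x

when isMainModule:
  let m = 1000003'u64
  let c = montyConstants(m, 64)
  doAssert c.r == (0'u64 - m) mod m      # 2^64 mod m, via wrapping arithmetic
  doAssert c.r2 == (c.r * c.r) mod m
  echo "R and R^2 OK"
```

This runs once per modulus, so its cost is amortized over the hundreds of multiplications of one exponentiation.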

You will also need 1/M0 (mod 2⁶⁴) with M0 being the first limb of your modulus M.
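
This word-sized inverse can be computed without division by Newton iteration, as sketched below. The function names are hypothetical; many Montgomery formulations store the negated value -1/M0, so both are shown.

```nim
func invModWord(m0: uint64): uint64 =
  ## 1/m0 mod 2^64 for odd m0.
  doAssert (m0 and 1) == 1, "modulus must be odd"
  result = m0                               # m0*m0 = 1 (mod 8): 3 correct bits
  for i in 0 ..< 5:                         # 3 -> 6 -> 12 -> 24 -> 48 -> 96 bits
    result = result * (2'u64 - m0 * result) # wrapping arithmetic mod 2^64

func negInvModWord(m0: uint64): uint64 =
  ## -1/m0 mod 2^64.
  0'u64 - invModWord(m0)

when isMainModule:
  let m0 = 0xFFFFFFFF00000001'u64
  doAssert m0 * invModWord(m0) == 1'u64
  doAssert m0 * negInvModWord(m0) == high(uint64)
  echo "word inverse OK"
```

This also shows why the modulus must be odd: an even M0 has no inverse modulo a power of two.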

Wrapup

The implementation steps should be:

  • Implement Montgomery Modular Multiplication montMul
  • Compute Montgomery constants R and R²
    • implement conversion to Montgomery domain with montMul(a, R²) = a' = aR (mod m)
    • implement conversion from Montgomery domain with montMul(a', 1) = a (mod m)
  • implement exponentiation by squaring
  • Optional
    • implement window method
    • implement assembly

and use Clang

@mratsim
Contributor Author

mratsim commented May 29, 2023

I have a local implementation based on Constantine (not fuzzed yet, only basic testing).

Performance improvement from 53.7x to 82.9x on 256-bit inputs

bench code

import
  ../constantine/math/arithmetic,
  ../constantine/math/io/[io_bigints, io_fields],
  ../constantine/math_arbitrary_precision/arithmetic/[bigints_views, limbs_views, limbs_montgomery, limbs_mod2k],
  ../constantine/math/config/[type_bigint, curves, precompute],
  ../constantine/platforms/[abstractions, codecs],
  ../helpers/prng_unsafe,
  std/[times, monotimes, strformat]

import stint

# let M = Mod(BN254_Snarks)
const bits = 256
const expBits = bits # Stint only supports same size args

var rng: RngState
rng.seed(1234)

for i in 0 ..< 5:
  echo "\ni: ", i
  # -------------------------
  let M = rng.random_long01Seq(BigInt[bits])
  let a = rng.random_long01Seq(BigInt[bits])

  var exponent = newSeq[byte](expBits div 8)
  for i in 0 ..< expBits div 8:
    exponent[i] = byte rng.next()

  # -------------------------

  let aHex = a.toHex()
  let eHex = exponent.toHex()
  let mHex = M.toHex()

  echo "  base:     ", a.toHex()
  echo "  exponent: ", exponent.toHex()
  echo "  modulus:  ", M.toHex()

  # -------------------------

  var elapsedCtt, elapsedStint: int64

  block:
    var r: BigInt[bits]
    let start = getMonotime()
    r.limbs.powMod_vartime(a.limbs, exponent, M.limbs, window = 4)
    let stop = getMonotime()

    elapsedCtt = inNanoseconds(stop-start)

    echo "  r Constantine:       ", r.toHex()
    echo "  elapsed Constantine: ", elapsedCtt, " ns"

  # -------------------------

  block:
    let aa = Stuint[bits].fromHex(aHex)
    let ee = Stuint[expBits].fromHex(eHex)
    let mm = Stuint[bits].fromHex(mHex)

    var r: Stuint[bits]
    let start = getMonotime()
    r = powmod(aa, ee, mm)
    let stop = getMonotime()

    elapsedStint = inNanoseconds(stop-start)

    echo "  r stint:             ", r.toHex()
    echo "  elapsed Stint:       ", elapsedStint, " ns"

  echo &"\n  ratio Stint/Constantine: {float64(elapsedStint)/float64(elapsedCtt):.3f}x"
  echo "---------------------------------------------------------"
  • Compiler is clang, and compiled with -d:danger. GCC is very bad at multiprecision arithmetic.
  • No assembly is used, it's a whole new arbitrary-precision integer backend. Assembly would likely gain ~30% perf over Clang (and probably 80% over GCC).
  • a constant-time modulo is used. Unsure about the perf impact at the moment.
  • implementing sliding window for exponentiation might gain 15~20% (https://en.wikipedia.org/wiki/Exponentiation_by_squaring#Sliding-window_method)

@mratsim
Contributor Author

mratsim commented May 29, 2023

Redid the bench vs GMP: Constantine is within about ±10% of GMP, sometimes slower, sometimes faster (without assembly!)

bench

import
  ../constantine/math/arithmetic,
  ../constantine/math/io/[io_bigints, io_fields],
  ../constantine/math_arbitrary_precision/arithmetic/[bigints_views, limbs_views, limbs_montgomery, limbs_mod2k],
  ../constantine/math/config/[type_bigint, curves, precompute],
  ../constantine/platforms/[abstractions, codecs],
  ../helpers/prng_unsafe,
  std/[times, monotimes, strformat]

import stint, gmp

const # https://gmplib.org/manual/Integer-Import-and-Export.html
  GMP_WordLittleEndian = -1'i32
  GMP_WordNativeEndian = 0'i32
  GMP_WordBigEndian = 1'i32

  GMP_MostSignificantWordFirst = 1'i32
  GMP_LeastSignificantWordFirst = -1'i32

# let M = Mod(BN254_Snarks)
const bits = 256
const expBits = bits # Stint only supports same size args

var rng: RngState
rng.seed(1234)

for i in 0 ..< 5:
  echo "i: ", i
  # -------------------------
  let M = rng.random_long01Seq(BigInt[bits])
  let a = rng.random_long01Seq(BigInt[bits])

  var exponent = newSeq[byte](expBits div 8)
  for i in 0 ..< expBits div 8:
    exponent[i] = byte rng.next()

  # -------------------------

  let aHex = a.toHex()
  let eHex = exponent.toHex()
  let mHex = M.toHex()

  echo "  base:     ", a.toHex()
  echo "  exponent: ", exponent.toHex()
  echo "  modulus:  ", M.toHex()

  # -------------------------

  var elapsedCtt, elapsedStint, elapsedGMP: int64

  block:
    var r: BigInt[bits]
    let start = getMonotime()
    r.limbs.powMod_vartime(a.limbs, exponent, M.limbs, window = 4)
    let stop = getMonotime()

    elapsedCtt = inNanoseconds(stop-start)

    echo "  r Constantine:       ", r.toHex()
    echo "  elapsed Constantine: ", elapsedCtt, " ns"

  # -------------------------

  block:
    let aa = Stuint[bits].fromHex(aHex)
    let ee = Stuint[expBits].fromHex(eHex)
    let mm = Stuint[bits].fromHex(mHex)

    var r: Stuint[bits]
    let start = getMonotime()
    r = powmod(aa, ee, mm)
    let stop = getMonotime()

    elapsedStint = inNanoseconds(stop-start)

    echo "  r stint:             ", r.toHex()
    echo "  elapsed Stint:       ", elapsedStint, " ns"

  block:
    var aa, ee, mm, rr: mpz_t
    mpz_init(aa)
    mpz_init(ee)
    mpz_init(mm)
    mpz_init(rr)

    aa.mpz_import(a.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, a.limbs[0].unsafeAddr)
    let e = BigInt[expBits].unmarshal(exponent, bigEndian)
    ee.mpz_import(e.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, e.limbs[0].unsafeAddr)
    mm.mpz_import(M.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, M.limbs[0].unsafeAddr)

    let start = getMonotime()
    rr.mpz_powm(aa, ee, mm)
    let stop = getMonotime()

    elapsedGMP = inNanoSeconds(stop-start)

    var r: BigInt[bits]
    var rWritten: csize
    discard r.limbs[0].addr.mpz_export(rWritten.addr, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, rr)

    echo "  r GMP:               ", r.toHex()
    echo "  elapsed GMP:         ", elapsedGMP, " ns"

  echo &"\n  ratio Stint/Constantine: {float64(elapsedStint)/float64(elapsedCtt):.3f}x"
  echo &"  ratio GMP/Constantine: {float64(elapsedGMP)/float64(elapsedCtt):.3f}x"
  echo "---------------------------------------------------------"
