From be7310e49562c9cd4c706e2ed77164ba28a05ebb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 12 Mar 2024 17:53:24 +0100 Subject: [PATCH] gh-73468: Add math.fma() function Added new math.fma() function, wrapping C99's ``fma()`` operation: fused multiply-add function. Co-Authored-By: Mark Dickinson --- Doc/library/math.rst | 16 ++ Doc/whatsnew/3.13.rst | 10 + Lib/test/test_math.py | 238 ++++++++++++++++++ ...4-03-12-17-53-14.gh-issue-73468.z4ZzvJ.rst | 2 + Modules/clinic/mathmodule.c.h | 63 ++++- Modules/mathmodule.c | 43 ++++ 6 files changed, 371 insertions(+), 1 deletion(-) create mode 100644 Misc/NEWS.d/next/Library/2024-03-12-17-53-14.gh-issue-73468.z4ZzvJ.rst diff --git a/Doc/library/math.rst b/Doc/library/math.rst index 93755be717e2ef..1475d26486de5f 100644 --- a/Doc/library/math.rst +++ b/Doc/library/math.rst @@ -82,6 +82,22 @@ Number-theoretic and representation functions should return an :class:`~numbers.Integral` value. +.. function:: fma(x, y, z) + + Fused multiply-add operation. Return ``(x * y) + z``, computed as though with + infinite precision and range followed by a single round to the ``float`` + format. This operation often provides better accuracy than the direct + expression ``(x * y) + z``. + + This function follows the specification of the fusedMultiplyAdd operation + described in the IEEE 754 standard. The standard leaves one case + implementation-defined, namely the result of ``fma(0, inf, nan)`` + and ``fma(inf, 0, nan)``. In these cases, ``math.fma`` returns a NaN, + and does not raise any exception. + + .. versionadded:: 3.13 + + .. function:: fmod(x, y) Return ``fmod(x, y)``, as defined by the platform C library. Note that the diff --git a/Doc/whatsnew/3.13.rst b/Doc/whatsnew/3.13.rst index 03ae6018905380..78b5868b6549c6 100644 --- a/Doc/whatsnew/3.13.rst +++ b/Doc/whatsnew/3.13.rst @@ -383,6 +383,16 @@ marshal code objects which are incompatible between Python versions. (Contributed by Serhiy Storchaka in :gh:`113626`.) 
+math +---- + +A new function :func:`~math.fma` for fused multiply-add operations has been +added. This function computes ``x * y + z`` with only a single round, and so +avoids any intermediate loss of precision. It wraps the ``fma()`` function +provided by C99, and follows the specification of the IEEE 754 +"fusedMultiplyAdd" operation for special cases. +(Contributed by Mark Dickinson and Victor Stinner in :gh:`73468`.) + mmap ---- diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index ad382fc2b59891..aaa3b16d33fb7d 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -2613,6 +2613,244 @@ def test_fractions(self): self.assertAllNotClose(fraction_examples, rel_tol=1e-9) +class FMATests(unittest.TestCase): + """ Tests for math.fma. """ + + def test_fma_nan_results(self): + # Selected representative values. + values = [ + -math.inf, -1e300, -2.3, -1e-300, -0.0, + 0.0, 1e-300, 2.3, 1e300, math.inf, math.nan + ] + + # If any input is a NaN, the result should be a NaN, too. + for a, b in itertools.product(values, repeat=2): + self.assertIsNaN(math.fma(math.nan, a, b)) + self.assertIsNaN(math.fma(a, math.nan, b)) + self.assertIsNaN(math.fma(a, b, math.nan)) + + def test_fma_infinities(self): + # Cases involving infinite inputs or results. + positives = [1e-300, 2.3, 1e300, math.inf] + finites = [-1e300, -2.3, -1e-300, -0.0, 0.0, 1e-300, 2.3, 1e300] + non_nans = [-math.inf, -2.3, -0.0, 0.0, 2.3, math.inf] + + # ValueError due to inf * 0 computation. + for c in non_nans: + for infinity in [math.inf, -math.inf]: + for zero in [0.0, -0.0]: + with self.assertRaises(ValueError): + math.fma(infinity, zero, c) + with self.assertRaises(ValueError): + math.fma(zero, infinity, c) + + # ValueError when a*b and c both infinite of opposite signs. 
+ for b in positives: + with self.assertRaises(ValueError): + math.fma(math.inf, b, -math.inf) + with self.assertRaises(ValueError): + math.fma(math.inf, -b, math.inf) + with self.assertRaises(ValueError): + math.fma(-math.inf, -b, -math.inf) + with self.assertRaises(ValueError): + math.fma(-math.inf, b, math.inf) + with self.assertRaises(ValueError): + math.fma(b, math.inf, -math.inf) + with self.assertRaises(ValueError): + math.fma(-b, math.inf, math.inf) + with self.assertRaises(ValueError): + math.fma(-b, -math.inf, -math.inf) + with self.assertRaises(ValueError): + math.fma(b, -math.inf, math.inf) + + # Infinite result when a*b and c both infinite of the same sign. + for b in positives: + self.assertEqual(math.fma(math.inf, b, math.inf), math.inf) + self.assertEqual(math.fma(math.inf, -b, -math.inf), -math.inf) + self.assertEqual(math.fma(-math.inf, -b, math.inf), math.inf) + self.assertEqual(math.fma(-math.inf, b, -math.inf), -math.inf) + self.assertEqual(math.fma(b, math.inf, math.inf), math.inf) + self.assertEqual(math.fma(-b, math.inf, -math.inf), -math.inf) + self.assertEqual(math.fma(-b, -math.inf, math.inf), math.inf) + self.assertEqual(math.fma(b, -math.inf, -math.inf), -math.inf) + + # Infinite result when a*b finite, c infinite. + for a, b in itertools.product(finites, finites): + self.assertEqual(math.fma(a, b, math.inf), math.inf) + self.assertEqual(math.fma(a, b, -math.inf), -math.inf) + + # Infinite result when a*b infinite, c finite. 
+ for b, c in itertools.product(positives, finites): + self.assertEqual(math.fma(math.inf, b, c), math.inf) + self.assertEqual(math.fma(-math.inf, b, c), -math.inf) + self.assertEqual(math.fma(-math.inf, -b, c), math.inf) + self.assertEqual(math.fma(math.inf, -b, c), -math.inf) + + self.assertEqual(math.fma(b, math.inf, c), math.inf) + self.assertEqual(math.fma(b, -math.inf, c), -math.inf) + self.assertEqual(math.fma(-b, -math.inf, c), math.inf) + self.assertEqual(math.fma(-b, math.inf, c), -math.inf) + + # gh-73468: On WASI and FreeBSD, libc fma() doesn't implement IEEE 754-2008 + # properly: it doesn't use the right sign when the result is zero. + @unittest.skipIf(support.is_wasi, + "WASI fma() doesn't implement IEE 754-2008 properly") + @unittest.skipIf(sys.platform.startswith('freebsd'), + "FreeBSD fma() doesn't implement IEE 754-2008 properly") + def test_fma_zero_result(self): + nonnegative_finites = [0.0, 1e-300, 2.3, 1e300] + + # Zero results from exact zero inputs. + for b in nonnegative_finites: + self.assertIsPositiveZero(math.fma(0.0, b, 0.0)) + self.assertIsPositiveZero(math.fma(0.0, b, -0.0)) + self.assertIsNegativeZero(math.fma(0.0, -b, -0.0)) + self.assertIsPositiveZero(math.fma(0.0, -b, 0.0)) + self.assertIsPositiveZero(math.fma(-0.0, -b, 0.0)) + self.assertIsPositiveZero(math.fma(-0.0, -b, -0.0)) + self.assertIsNegativeZero(math.fma(-0.0, b, -0.0)) + self.assertIsPositiveZero(math.fma(-0.0, b, 0.0)) + + self.assertIsPositiveZero(math.fma(b, 0.0, 0.0)) + self.assertIsPositiveZero(math.fma(b, 0.0, -0.0)) + self.assertIsNegativeZero(math.fma(-b, 0.0, -0.0)) + self.assertIsPositiveZero(math.fma(-b, 0.0, 0.0)) + self.assertIsPositiveZero(math.fma(-b, -0.0, 0.0)) + self.assertIsPositiveZero(math.fma(-b, -0.0, -0.0)) + self.assertIsNegativeZero(math.fma(b, -0.0, -0.0)) + self.assertIsPositiveZero(math.fma(b, -0.0, 0.0)) + + # Exact zero result from nonzero inputs. 
+ self.assertIsPositiveZero(math.fma(2.0, 2.0, -4.0)) + self.assertIsPositiveZero(math.fma(2.0, -2.0, 4.0)) + self.assertIsPositiveZero(math.fma(-2.0, -2.0, -4.0)) + self.assertIsPositiveZero(math.fma(-2.0, 2.0, 4.0)) + + # Underflow to zero. + tiny = 1e-300 + self.assertIsPositiveZero(math.fma(tiny, tiny, 0.0)) + self.assertIsNegativeZero(math.fma(tiny, -tiny, 0.0)) + self.assertIsPositiveZero(math.fma(-tiny, -tiny, 0.0)) + self.assertIsNegativeZero(math.fma(-tiny, tiny, 0.0)) + self.assertIsPositiveZero(math.fma(tiny, tiny, -0.0)) + self.assertIsNegativeZero(math.fma(tiny, -tiny, -0.0)) + self.assertIsPositiveZero(math.fma(-tiny, -tiny, -0.0)) + self.assertIsNegativeZero(math.fma(-tiny, tiny, -0.0)) + + # Corner case where rounding the multiplication would + # give the wrong result. + x = float.fromhex('0x1p-500') + y = float.fromhex('0x1p-550') + z = float.fromhex('0x1p-1000') + self.assertIsNegativeZero(math.fma(x-y, x+y, -z)) + self.assertIsPositiveZero(math.fma(y-x, x+y, z)) + self.assertIsNegativeZero(math.fma(y-x, -(x+y), -z)) + self.assertIsPositiveZero(math.fma(x-y, -(x+y), z)) + + def test_fma_overflow(self): + a = b = float.fromhex('0x1p512') + c = float.fromhex('0x1p1023') + # Overflow from multiplication. + with self.assertRaises(OverflowError): + math.fma(a, b, 0.0) + self.assertEqual(math.fma(a, b/2.0, 0.0), c) + # Overflow from the addition. + with self.assertRaises(OverflowError): + math.fma(a, b/2.0, c) + # No overflow, even though a*b overflows a float. + self.assertEqual(math.fma(a, b, -c), c) + + # Extreme case: a * b is exactly at the overflow boundary, so the + # tiniest offset makes a difference between overflow and a finite + # result. 
+ a = float.fromhex('0x1.ffffffc000000p+511') + b = float.fromhex('0x1.0000002000000p+512') + c = float.fromhex('0x0.0000000000001p-1022') + with self.assertRaises(OverflowError): + math.fma(a, b, 0.0) + with self.assertRaises(OverflowError): + math.fma(a, b, c) + self.assertEqual(math.fma(a, b, -c), + float.fromhex('0x1.fffffffffffffp+1023')) + + # Another extreme case: here a*b is about as large as possible subject + # to math.fma(a, b, c) being finite. + a = float.fromhex('0x1.ae565943785f9p+512') + b = float.fromhex('0x1.3094665de9db8p+512') + c = float.fromhex('0x1.fffffffffffffp+1023') + self.assertEqual(math.fma(a, b, -c), c) + + def test_fma_single_round(self): + a = float.fromhex('0x1p-50') + self.assertEqual(math.fma(a - 1.0, a + 1.0, 1.0), a*a) + + def test_random(self): + # A collection of randomly generated inputs for which the naive FMA + # (with two rounds) gives a different result from a singly-rounded FMA. + + # tuples (a, b, c, expected) + test_values = [ + ('0x1.694adde428b44p-1', '0x1.371b0d64caed7p-1', + '0x1.f347e7b8deab8p-4', '0x1.19f10da56c8adp-1'), + ('0x1.605401ccc6ad6p-2', '0x1.ce3a40bf56640p-2', + '0x1.96e3bf7bf2e20p-2', '0x1.1af6d8aa83101p-1'), + ('0x1.e5abd653a67d4p-2', '0x1.a2e400209b3e6p-1', + '0x1.a90051422ce13p-1', '0x1.37d68cc8c0fbbp+0'), + ('0x1.f94e8efd54700p-2', '0x1.123065c812cebp-1', + '0x1.458f86fb6ccd0p-1', '0x1.ccdcee26a3ff3p-1'), + ('0x1.bd926f1eedc96p-1', '0x1.eee9ca68c5740p-1', + '0x1.960c703eb3298p-2', '0x1.3cdcfb4fdb007p+0'), + ('0x1.27348350fbccdp-1', '0x1.3b073914a53f1p-1', + '0x1.e300da5c2b4cbp-1', '0x1.4c51e9a3c4e29p+0'), + ('0x1.2774f00b3497bp-1', '0x1.7038ec336bff0p-2', + '0x1.2f6f2ccc3576bp-1', '0x1.99ad9f9c2688bp-1'), + ('0x1.51d5a99300e5cp-1', '0x1.5cd74abd445a1p-1', + '0x1.8880ab0bbe530p-1', '0x1.3756f96b91129p+0'), + ('0x1.73cb965b821b8p-2', '0x1.218fd3d8d5371p-1', + '0x1.d1ea966a1f758p-2', '0x1.5217b8fd90119p-1'), + ('0x1.4aa98e890b046p-1', '0x1.954d85dff1041p-1', + '0x1.122b59317ebdfp-1', 
'0x1.0bf644b340cc5p+0'), + ('0x1.e28f29e44750fp-1', '0x1.4bcc4fdcd18fep-1', + '0x1.fd47f81298259p-1', '0x1.9b000afbc9995p+0'), + ('0x1.d2e850717fe78p-3', '0x1.1dd7531c303afp-1', + '0x1.e0869746a2fc2p-2', '0x1.316df6eb26439p-1'), + ('0x1.cf89c75ee6fbap-2', '0x1.b23decdc66825p-1', + '0x1.3d1fe76ac6168p-1', '0x1.00d8ea4c12abbp+0'), + ('0x1.3265ae6f05572p-2', '0x1.16d7ec285f7a2p-1', + '0x1.0b8405b3827fbp-1', '0x1.5ef33c118a001p-1'), + ('0x1.c4d1bf55ec1a5p-1', '0x1.bc59618459e12p-2', + '0x1.ce5b73dc1773dp-1', '0x1.496cf6164f99bp+0'), + ('0x1.d350026ac3946p-1', '0x1.9a234e149a68cp-2', + '0x1.f5467b1911fd6p-2', '0x1.b5cee3225caa5p-1'), + ] + for a_hex, b_hex, c_hex, expected_hex in test_values: + a = float.fromhex(a_hex) + b = float.fromhex(b_hex) + c = float.fromhex(c_hex) + expected = float.fromhex(expected_hex) + self.assertEqual(math.fma(a, b, c), expected) + self.assertEqual(math.fma(b, a, c), expected) + + # Custom assertions. + def assertIsNaN(self, value): + self.assertTrue( + math.isnan(value), + msg="Expected a NaN, got {!r}".format(value) + ) + + def assertIsPositiveZero(self, value): + self.assertTrue( + value == 0 and math.copysign(1, value) > 0, + msg="Expected a positive zero, got {!r}".format(value) + ) + + def assertIsNegativeZero(self, value): + self.assertTrue( + value == 0 and math.copysign(1, value) < 0, + msg="Expected a negative zero, got {!r}".format(value) + ) + + def load_tests(loader, tests, pattern): from doctest import DocFileSuite tests.addTest(DocFileSuite(os.path.join("mathdata", "ieee754.txt"))) diff --git a/Misc/NEWS.d/next/Library/2024-03-12-17-53-14.gh-issue-73468.z4ZzvJ.rst b/Misc/NEWS.d/next/Library/2024-03-12-17-53-14.gh-issue-73468.z4ZzvJ.rst new file mode 100644 index 00000000000000..c91f4eb97e06bc --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-03-12-17-53-14.gh-issue-73468.z4ZzvJ.rst @@ -0,0 +1,2 @@ +Added new :func:`math.fma` function, wrapping C99's ``fma()`` operation: +fused multiply-add function. 
Patch by Mark Dickinson and Victor Stinner. diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h index ca14c03f16f706..666b6b3790dae5 100644 --- a/Modules/clinic/mathmodule.c.h +++ b/Modules/clinic/mathmodule.c.h @@ -204,6 +204,67 @@ PyDoc_STRVAR(math_log10__doc__, #define MATH_LOG10_METHODDEF \ {"log10", (PyCFunction)math_log10, METH_O, math_log10__doc__}, +PyDoc_STRVAR(math_fma__doc__, +"fma($module, x, y, z, /)\n" +"--\n" +"\n" +"Fused multiply-add operation.\n" +"\n" +"Compute (x * y) + z with a single round."); + +#define MATH_FMA_METHODDEF \ + {"fma", _PyCFunction_CAST(math_fma), METH_FASTCALL, math_fma__doc__}, + +static PyObject * +math_fma_impl(PyObject *module, double x, double y, double z); + +static PyObject * +math_fma(PyObject *module, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + double x; + double y; + double z; + + if (!_PyArg_CheckPositional("fma", nargs, 3, 3)) { + goto exit; + } + if (PyFloat_CheckExact(args[0])) { + x = PyFloat_AS_DOUBLE(args[0]); + } + else + { + x = PyFloat_AsDouble(args[0]); + if (x == -1.0 && PyErr_Occurred()) { + goto exit; + } + } + if (PyFloat_CheckExact(args[1])) { + y = PyFloat_AS_DOUBLE(args[1]); + } + else + { + y = PyFloat_AsDouble(args[1]); + if (y == -1.0 && PyErr_Occurred()) { + goto exit; + } + } + if (PyFloat_CheckExact(args[2])) { + z = PyFloat_AS_DOUBLE(args[2]); + } + else + { + z = PyFloat_AsDouble(args[2]); + if (z == -1.0 && PyErr_Occurred()) { + goto exit; + } + } + return_value = math_fma_impl(module, x, y, z); + +exit: + return return_value; +} + PyDoc_STRVAR(math_fmod__doc__, "fmod($module, x, y, /)\n" "--\n" @@ -950,4 +1011,4 @@ math_ulp(PyObject *module, PyObject *arg) exit: return return_value; } -/*[clinic end generated code: output=6b2eeaed8d8a76d5 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=9fe3f007f474e015 input=a9049054013a1b77]*/ diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 
a877bfcd6afb68..8ba0431f4a47b7 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -2321,6 +2321,48 @@ math_log10(PyObject *module, PyObject *x) } +/*[clinic input] +math.fma + + x: double + y: double + z: double + / + +Fused multiply-add operation. + +Compute (x * y) + z with a single round. +[clinic start generated code]*/ + +static PyObject * +math_fma_impl(PyObject *module, double x, double y, double z) +/*[clinic end generated code: output=4fc8626dbc278d17 input=e3ad1f4a4c89626e]*/ +{ + double r = fma(x, y, z); + + /* Fast path: if we got a finite result, we're done. */ + if (Py_IS_FINITE(r)) { + return PyFloat_FromDouble(r); + } + + /* Non-finite result. Raise an exception if appropriate, else return r. */ + if (Py_IS_NAN(r)) { + if (!Py_IS_NAN(x) && !Py_IS_NAN(y) && !Py_IS_NAN(z)) { + /* NaN result from non-NaN inputs. */ + PyErr_SetString(PyExc_ValueError, "invalid operation in fma"); + return NULL; + } + } + else if (Py_IS_FINITE(x) && Py_IS_FINITE(y) && Py_IS_FINITE(z)) { + /* Infinite result from finite inputs. */ + PyErr_SetString(PyExc_OverflowError, "overflow in fma"); + return NULL; + } + + return PyFloat_FromDouble(r); +} + + /*[clinic input] math.fmod @@ -4094,6 +4136,7 @@ static PyMethodDef math_methods[] = { {"fabs", math_fabs, METH_O, math_fabs_doc}, MATH_FACTORIAL_METHODDEF MATH_FLOOR_METHODDEF + MATH_FMA_METHODDEF MATH_FMOD_METHODDEF MATH_FREXP_METHODDEF MATH_FSUM_METHODDEF