From 59a7d11befc1fc727124d673dd1ed63e9621d456 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 16 Aug 2023 18:56:46 -0500 Subject: [PATCH] Scale down arguments and scale back the result This improves accuracy at extremes of supported range. --- .../kernels/elementwise_functions/sqrt.hpp | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index f72b5d963b..90ca525a54 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -148,16 +148,26 @@ template struct SqrtFunctor constexpr realT half = realT(0x1.0p-1f); // 1/2 constexpr realT zero = realT(0); - if (std::signbit(x)) { - realT m = std::hypot(x, y); - realT d = std::sqrt((m - x) * half); - return {(d == zero ? zero : std::abs(y) / d * half), - std::copysign(d, y)}; + const int exp_x = std::ilogb(x); + const int exp_y = std::ilogb(y); + + int sc = std::max(exp_x, exp_y) / 2; + const realT xx = std::ldexp(x, -sc * 2); + const realT yy = std::ldexp(y, -sc * 2); + + if (std::signbit(xx)) { + const realT m = std::hypot(xx, yy); + const realT d = std::sqrt((m - xx) * half); + const realT res_re = (d == zero ? zero : std::abs(yy) / d * half); + const realT res_im = std::copysign(d, yy); + return {std::ldexp(res_re, sc), std::ldexp(res_im, sc)}; } else { - realT m = std::hypot(x, y); - realT d = std::sqrt((m + x) * half); - return {d, (d == zero) ? std::copysign(zero, y) : y * half / d}; + const realT m = std::hypot(xx, yy); + const realT d = std::sqrt((m + xx) * half); + const realT res_im = + (d == zero) ? std::copysign(zero, yy) : yy * half / d; + return {std::ldexp(d, sc), std::ldexp(res_im, sc)}; } } };