diff --git a/src/intervals/arithmetic.jl b/src/intervals/arithmetic.jl index ed8360e5..a2545380 100644 --- a/src/intervals/arithmetic.jl +++ b/src/intervals/arithmetic.jl @@ -235,21 +235,72 @@ end //(a::Interval, b::Interval) = a / b # to deal with rationals -function min_ignore_nans(args...) - min(Iterators.filter(x->!isnan(x), args)...) +## fma: fused multiply-add +""" + hi, lo = directed_fma(a::T, b::T, c::T) where {T} + +computes fma(a, b, c) rounded up (hi) and rounded down (lo) +""" +function directed_fma(a::T, b::T, c::T) where {T} + + hi = fma(a, b, c) + isnan(hi) && return convert(T, -Inf), convert(T, Inf) + !isfinite(hi) && return hi, hi + + hi, lo = two_fma(a, b, c) + if signbit(lo) + lo = prevfloat(hi) + elseif !iszero(lo) + lo = hi + hi = nextfloat(hi) + else + lo = hi + end + return hi, lo end -function max_ignore_nans(args...) - max(Iterators.filter(x->!isnan(x), args)...) +""" + two_fma(a, b, c) + +Computes `val = fl(fma(a, b, c))` and `err = fl(err(fma(a, b, c)))`. +""" +function two_fma(a::T, b::T, c::T) where {T} + val = fma(a, b, c) + val0, err0 = two_prod(a, b) + val1, err1 = two_sum(c, err0) + val2, err2 = two_sum(val0, val1) + err = ((val2 - val) + err2) + err1 + return val, err end +""" + two_sum(a, b) + +Computes `val = fl(a+b)` and `err = err(a+b)`. +""" +@inline function two_sum(a::T, b::T) where {T} + val = a + b + v = val - a + err = (a - (val - v)) + (b - v) + return val, err +end + +""" + two_prod(a, b) + +Computes `val = fl(a*b)` and `err = fl(err(a*b))`. +""" +@inline function two_prod(a::T, b::T) where {T} + val = a * b + err = fma(a, b, -val) + val, err +end -## fma: fused multiply-add function fma(a::Interval{T}, b::Interval{T}, c::Interval{T}) where T - #T = promote_type(eltype(a), eltype(b), eltype(c)) (isempty(a) || isempty(b) || isempty(c)) && return emptyinterval(T) + isnan(a+b+c) && return a + b + c if isentire(a) b == zero(b) && return c @@ -259,28 +310,25 @@ function fma(a::Interval{T}, b::Interval{T}, c::Interval{T}) where T a == zero(a) && return c return entireinterval(T) + elseif isentire(c) + return entireinterval(T) end - lo = setrounding(T, RoundDown) do - lo1 = fma(a.lo, b.lo, c.lo) - lo2 = fma(a.lo, b.hi, c.lo) - lo3 = fma(a.hi, b.lo, c.lo) - lo4 = fma(a.hi, b.hi, c.lo) - min_ignore_nans(lo1, lo2, lo3, lo4) - end + _, lo1 = directed_fma(a.lo, b.lo, c.lo) + _, lo2 = directed_fma(a.lo, b.hi, c.lo) + _, lo3 = directed_fma(a.hi, b.lo, c.lo) + _, lo4 = directed_fma(a.hi, b.hi, c.lo) + lo = min(lo1, lo2, lo3, lo4) - hi = setrounding(T, RoundUp) do - hi1 = fma(a.lo, b.lo, c.hi) - hi2 = fma(a.lo, b.hi, c.hi) - hi3 = fma(a.hi, b.lo, c.hi) - hi4 = fma(a.hi, b.hi, c.hi) - max_ignore_nans(hi1, hi2, hi3, hi4) - end + hi1, _ = directed_fma(a.lo, b.lo, c.hi) + hi2, _ = directed_fma(a.lo, b.hi, c.hi) + hi3, _ = directed_fma(a.hi, b.lo, c.hi) + hi4, _ = directed_fma(a.hi, b.hi, c.hi) + hi = max(hi1, hi2, hi3, hi4) - Interval(lo, hi) + return Interval(lo, hi) end - ## Scalar functions on intervals (no directed rounding used) function mag(a::Interval{T}) where T<:Real @@ -291,13 +339,15 @@ function mag(a::Interval{T}) where T<:Real max( abs(a.lo), abs(a.hi) ) end +""" + mig(a::Interval) + +Returns the mignitude of an interval, defined as mig(X) = min {|x|: x ∈ X} +""" function mig(a::Interval{T}) where T<:Real isempty(a) && return convert(eltype(a), NaN) zero(a.lo) ∈ a && return zero(a.lo) - r1, r2 = setrounding(T, RoundDown) do - abs(a.lo), abs(a.hi) - end - min( r1, r2 ) + return min( abs(a.lo), abs(a.hi) ) end