Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove setrounding from functions #443

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
102 changes: 76 additions & 26 deletions src/intervals/arithmetic.jl
Expand Up @@ -235,21 +235,72 @@ end
//(a::Interval, b::Interval) = a / b # to deal with rationals


function min_ignore_nans(args...)
min(Iterators.filter(x->!isnan(x), args)...)
## fma: fused multiply-add
"""
hi, lo = directed_fma(a::T, b::T, c::T) where {T}

computes fma(a, b, c) rounded up (hi) and rounded down (lo)
"""
function directed_fma(a::T, b::T, c::T) where {T}

hi = fma(a, b, c)
isnan(hi) && return convert(T, -Inf), convert(T, Inf)
!isfinite(hi) && return hi, hi

hi, lo = two_fma(a, b, c)
if signbit(lo)
lo = prevfloat(hi)
elseif !iszero(lo)
lo = hi
hi = nextfloat(hi)
else
lo = hi
end
Comment on lines +251 to +258
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I follow the algorithm here, so references or details are appreciated, nor if lo and hi may be any arbitrary values.

Yet, I think it is worth to check first if lo is zero (instead of getting its sign) and then deal with the sign part. If you have lo=-0.0 (zero is signed in floating point!) this block yields lo = prevfloat(hi) (first condition is satisfied), whereas if you have lo=+0.0 it yields hi (the else is used). My point is that you get two different answers for the same "value" of lo, and I do know know which one is the correct one.

return hi, lo
end

function max_ignore_nans(args...)
max(Iterators.filter(x->!isnan(x), args)...)
"""
two_fma(a, b, c)

Computes `val = fl(fma(a, b, c))` and `err = fl(err(fma(a, b, c)))`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide a reference or more details, to understand why you do what you do?

"""
function two_fma(a::T, b::T, c::T) where {T}
val = fma(a, b, c)
val0, err0 = two_prod(a, b)
val1, err1 = two_sum(c, err0)
val2, err2 = two_sum(val0, val1)
err = ((val2 - val) + err2) + err1
return val, err
end

"""
two_sum(a, b)

Computes `val = fl(a+b)` and `err = err(a+b)`.
"""
@inline function two_sum(a::T, b::T) where {T}
val = a + b
v = val - a
err = (a - (val - v)) + (b - v)
return val, err
end

"""
two_prod(a, b)

Computes `val = fl(a*b)` and `err = fl(err(a*b))`.
"""
@inline function two_prod(a::T, b::T) where {T}
val = a * b
err = fma(a, b, -val)
val, err
end


## fma: fused multiply-add
function fma(a::Interval{T}, b::Interval{T}, c::Interval{T}) where T
#T = promote_type(eltype(a), eltype(b), eltype(c))

(isempty(a) || isempty(b) || isempty(c)) && return emptyinterval(T)
isnan(a+b+c) && return a + b + c

if isentire(a)
b == zero(b) && return c
Expand All @@ -259,28 +310,25 @@ function fma(a::Interval{T}, b::Interval{T}, c::Interval{T}) where T
a == zero(a) && return c
return entireinterval(T)

elseif isentire(c)
return entireinterval(T)
end

lo = setrounding(T, RoundDown) do
lo1 = fma(a.lo, b.lo, c.lo)
lo2 = fma(a.lo, b.hi, c.lo)
lo3 = fma(a.hi, b.lo, c.lo)
lo4 = fma(a.hi, b.hi, c.lo)
min_ignore_nans(lo1, lo2, lo3, lo4)
end
_, lo1 = directed_fma(a.lo, b.lo, c.lo)
_, lo2 = directed_fma(a.lo, b.hi, c.lo)
_, lo3 = directed_fma(a.hi, b.lo, c.lo)
_, lo4 = directed_fma(a.hi, b.hi, c.lo)
lo = min(lo1, lo2, lo3, lo4)

hi = setrounding(T, RoundUp) do
hi1 = fma(a.lo, b.lo, c.hi)
hi2 = fma(a.lo, b.hi, c.hi)
hi3 = fma(a.hi, b.lo, c.hi)
hi4 = fma(a.hi, b.hi, c.hi)
max_ignore_nans(hi1, hi2, hi3, hi4)
end
hi1, _ = directed_fma(a.lo, b.lo, c.hi)
hi2, _ = directed_fma(a.lo, b.hi, c.hi)
hi3, _ = directed_fma(a.hi, b.lo, c.hi)
hi4, _ = directed_fma(a.hi, b.hi, c.hi)
hi = max(hi1, hi2, hi3, hi4)

Interval(lo, hi)
return Interval(lo, hi)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that, in the way you define lo and hi, the returned interval includes the true answer. However, I can't see if this is the tightest interval or not, and may be it isn't. Can you comment on this?

end


## Scalar functions on intervals (no directed rounding used)

function mag(a::Interval{T}) where T<:Real
Expand All @@ -291,13 +339,15 @@ function mag(a::Interval{T}) where T<:Real
max( abs(a.lo), abs(a.hi) )
end

"""
mig(a::Interval)

Returns the mignitude of an interval, defined as mig(X) = min {|x|: x ∈ X}
"""
function mig(a::Interval{T}) where T<:Real
isempty(a) && return convert(eltype(a), NaN)
zero(a.lo) ∈ a && return zero(a.lo)
r1, r2 = setrounding(T, RoundDown) do
abs(a.lo), abs(a.hi)
end
min( r1, r2 )
return min( abs(a.lo), abs(a.hi) )
end


Expand Down