Skip to content

Commit

Permalink
various fixes to array conversions (#668)
Browse files Browse the repository at this point in the history
* 0-dimensional PyArray

* gc safety and CartesianIndex support in PyArray

* PyArray bounds checking

* rm redundant methods

* another simplification

* no sum for empty tuples

* fix docstring

* consolidation of Array{PyObject} conversion, add a missing GC root, pysequence check fix
  • Loading branch information
stevengj committed Mar 22, 2019
1 parent 0cb5c45 commit 33f07eb
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 109 deletions.
16 changes: 12 additions & 4 deletions src/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ function pyarray_dims(o::PyObject, forcelist=true)
return () # too many non-List types can pretend to be sequences
end
len = ccall((@pysym :PySequence_Size), Int, (PyPtr,), o)
len < 0 && error("not a PySequence object")
if len == 0
return (0,)
end
Expand All @@ -392,12 +393,18 @@ function pyarray_dims(o::PyObject, forcelist=true)
end

function py2array(T, o::PyObject)
dims = pyarray_dims(o)
b = PyBuffer()
if isbuftype!(o, b)
dims = size(b)
else
dims = pyarray_dims(o)
end
pydecref(b) # safe for immediate release
A = Array{pyany_toany(T)}(undef, dims)
py2array(T, A, o, 1, 1)
py2array(T, A, o, 1, 1) # fixme: faster conversion for supported buffer types?
end

function convert(::Type{Vector{T}}, o::PyObject) where T
function py2vector(T, o::PyObject)
len = ccall((@pysym :PySequence_Size), Int, (PyPtr,), o)
if len < 0 || # not a sequence
len+1 < 0 # object pretending to be a sequence of infinite length
Expand All @@ -406,6 +413,7 @@ function convert(::Type{Vector{T}}, o::PyObject) where T
end
py2array(T, Array{pyany_toany(T)}(undef, len), o, 1, 1)
end
convert(::Type{Vector{T}}, o::PyObject) where T = py2vector(T, o)

convert(::Type{Array}, o::PyObject) = map(identity, py2array(PyAny, o))
convert(::Type{Array{T}}, o::PyObject) where {T} = py2array(T, o)
Expand Down Expand Up @@ -800,8 +808,8 @@ function pytype_query(o::PyObject, default::TypeTuple=PyObject)
@return_not_None pyfunction_query(o)
@return_not_None pydate_query(o)
@return_not_None pydict_query(o)
@return_not_None pysequence_query(o)
@return_not_None pyptr_query(o)
@return_not_None pysequence_query(o)
@return_not_None pynothing_query(o)
@return_not_None pymp_query(o)
for (py,jl) in pytype_queries
Expand Down
138 changes: 61 additions & 77 deletions src/pyarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ identical memory layout to a Julia `Array` of the same size.
`st` should be the stride(s) *in bytes* between elements in each dimension
"""
function f_contiguous(::Type{T}, sz::NTuple{N,Int}, st::NTuple{N,Int}) where {T,N}
N == 0 && return true # 0-dimensional arrays have 1 element, always contiguous
if st[1] != sizeof(T)
# not contiguous
return false
Expand Down Expand Up @@ -153,77 +154,72 @@ function copy(a::PyArray{T,N}) where {T,N}
return A
end

# TODO: need to do bounds-checking of these indices!
# TODO: need to GC root these `a`s to guard against the PyArray getting gc'd,
# e.g. if it's a temporary in a function:
# `two_rands() = pycall(np.rand, PyArray, 10)[1:2]`


getindex(a::PyArray{T,0}) where {T} = unsafe_load(a.data)
getindex(a::PyArray{T,1}, i::Integer) where {T} = unsafe_load(a.data, 1 + (i-1)*a.st[1])
unsafe_data_load(a::PyArray, i::Integer) = GC.@preserve a unsafe_load(a.data, i)

@inline data_index(a::PyArray{<:Any,N}, i::CartesianIndex{N}) where {N} =
1 + sum(ntuple(dim -> (i[dim]-1) * a.st[dim], Val{N}())) # Val lets julia unroll/inline
data_index(a::PyArray{<:Any,0}, i::CartesianIndex{0}) = 1

# handle passing fewer/more indices than dimensions by canonicalizing to M==N
@inline function fixindex(a::PyArray{<:Any,N}, i::CartesianIndex{M}) where {M,N}
if M == N
return i
elseif M < N
@boundscheck(all(ntuple(k -> size(a,k+M)==1, Val{N-M}())) ||
throw(BoundsError(a, i))) # trailing sizes must == 1
return CartesianIndex(Tuple(i)..., ntuple(k -> 1, Val{N-M}())...)
else # M > N
@boundscheck(all(ntuple(k -> i[k+N]==1, Val{M-N}())) ||
throw(BoundsError(a, i))) # trailing indices must == 1
return CartesianIndex(ntuple(k -> i[k], Val{N}()))
end
end

getindex(a::PyArray{T,2}, i::Integer, j::Integer) where {T} =
unsafe_load(a.data, 1 + (i-1)*a.st[1] + (j-1)*a.st[2])
@inline function getindex(a::PyArray, i::CartesianIndex)
j = fixindex(a, i)
@boundscheck checkbounds(a, j)
unsafe_data_load(a, data_index(a, j))
end
@inline getindex(a::PyArray, i::Integer...) = a[CartesianIndex(i)]
@inline getindex(a::PyArray{<:Any,1}, i::Integer) = a[CartesianIndex(i)]

# linear indexing
function getindex(a::PyArray, i::Integer)
@boundscheck checkbounds(a, i)
if a.f_contig
return unsafe_load(a.data, i)
return unsafe_data_load(a, i)
else
return a[ind2sub(a.dims, i)...]
@inbounds return a[CartesianIndices(a)[i]]
end
end

function getindex(a::PyArray, is::Integer...)
index = 1
n = min(length(is),length(a.st))
for i = 1:n
index += (is[i]-1)*a.st[i]
end
for i = n+1:length(is)
if is[i] != 1
throw(BoundsError())
end
end
unsafe_load(a.data, index)
end

function writeok_assign(a::PyArray, v, i::Integer)
if a.info.readonly
throw(ArgumentError("read-only PyArray"))
else
unsafe_store!(a.data, v, i)
GC.@preserve a unsafe_store!(a.data, v, i)
end
return a
return v
end

setindex!(a::PyArray{T,0}, v) where {T} = writeok_assign(a, v, 1)
setindex!(a::PyArray{T,1}, v, i::Integer) where {T} = writeok_assign(a, v, 1 + (i-1)*a.st[1])

setindex!(a::PyArray{T,2}, v, i::Integer, j::Integer) where {T} =
writeok_assign(a, v, 1 + (i-1)*a.st[1] + (j-1)*a.st[2])
@inline function setindex!(a::PyArray, v, i::CartesianIndex)
j = fixindex(a, i)
@boundscheck checkbounds(a, j)
writeok_assign(a, v, data_index(a, j))
end
@inline setindex!(a::PyArray, v, i::Integer...) = setindex!(a, v, CartesianIndex(i))
@inline setindex!(a::PyArray{<:Any,1}, v, i::Integer) = setindex!(a, v, CartesianIndex(i))

# linear indexing
function setindex!(a::PyArray, v, i::Integer)
@boundscheck checkbounds(a, i)
if a.f_contig
return writeok_assign(a, v, i)
else
return setindex!(a, v, ind2sub(a.dims, i)...)
@inbounds return setindex!(a, v, CartesianIndices(a)[i])
end
end

function setindex!(a::PyArray, v, is::Integer...)
index = 1
n = min(length(is),length(a.st))
for i = 1:n
index += (is[i]-1)*a.st[i]
end
for i = n+1:length(is)
if is[i] != 1
throw(BoundsError())
end
end
writeok_assign(a, v, index)
end

stride(a::PyArray, i::Integer) = a.st[i]

Base.unsafe_convert(::Type{Ptr{T}}, a::PyArray{T}) where {T} = a.data
Expand All @@ -244,68 +240,56 @@ summary(a::PyArray{T}) where {T} = string(Base.dims2string(size(a)), " ",
#########################################################################
# PyArray <-> PyObject conversions

const PYARR_TYPES = Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Float16,Float32,Float64,ComplexF32,ComplexF64,PyPtr}
const PYARR_TYPES = Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Float16,Float32,Float64,ComplexF32,ComplexF64,PyPtr,PyObject}

PyObject(a::PyArray) = a.o

convert(::Type{PyArray}, o::PyObject) = PyArray(o)

# PyObject arrays are created by taking a NumPy array of PyPtr and converting
pyo2ptr(T::Type) = T
pyo2ptr(::Type{PyObject}) = PyPtr
pyocopy(a) = copy(a)
pyocopy(a::AbstractArray{PyPtr}) = GC.@preserve a map(pyincref, a)

function convert(::Type{Array{T, 1}}, o::PyObject) where T<:PYARR_TYPES
try
copy(PyArray{T, 1}(o, PyArray_Info(o))) # will check T and N vs. info
return pyocopy(PyArray{pyo2ptr(T), 1}(o, PyArray_Info(o))) # will check T and N vs. info
catch
len = @pycheckz ccall((@pysym :PySequence_Size), Int, (PyPtr,), o)
A = Array{pyany_toany(T)}(undef, len)
py2array(T, A, o, 1, 1)
return py2vector(T, o)
end
end

function convert(::Type{Array{T}}, o::PyObject) where T<:PYARR_TYPES
try
info = PyArray_Info(o)
try
copy(PyArray{T, length(info.sz)}(o, info)) # will check T == eltype(info)
return pyocopy(PyArray{pyo2ptr(T), length(info.sz)}(o, info)) # will check T == eltype(info)
catch
return py2array(T, Array{pyany_toany(T)}(undef, info.sz...), o, 1, 1)
return py2array(T, Array{T}(undef, info.sz...), o, 1, 1)
end
catch
py2array(T, o)
return py2array(T, o)
end
end

function convert(::Type{Array{T,N}}, o::PyObject) where {T<:PYARR_TYPES,N}
try
info = PyArray_Info(o)
try
copy(PyArray{T,N}(o, info)) # will check T,N == eltype(info),ndims(info)
pyocopy(PyArray{pyo2ptr(T),N}(o, info)) # will check T,N == eltype(info),ndims(info)
catch
nd = length(info.sz)
if nd != N
throw(ArgumentError("cannot convert $(nd)d array to $(N)d"))
end
return py2array(T, Array{pyany_toany(T)}(undef, info.sz...), o, 1, 1)
nd == N || throw(ArgumentError("cannot convert $(nd)d array to $(N)d"))
return py2array(T, Array{T}(undef, info.sz...), o, 1, 1)
end
catch
A = py2array(T, o)
if ndims(A) != N
throw(ArgumentError("cannot convert $(ndims(A))d array to $(N)d"))
end
A
ndims(A) == N || throw(ArgumentError("cannot convert $(ndims(A))d array to $(N)d"))
return A
end
end

function convert(::Type{Array{PyObject}}, o::PyObject)
map(pyincref, convert(Array{PyPtr}, o))
end

function convert(::Type{Array{PyObject,1}}, o::PyObject)
map(pyincref, convert(Array{PyPtr, 1}, o))
end

function convert(::Type{Array{PyObject,N}}, o::PyObject) where N
map(pyincref, convert(Array{PyPtr, N}, o))
end

array_format(o::PyObject) = array_format(PyBuffer(o, PyBUF_ND_STRIDED))

"""
Expand Down
36 changes: 17 additions & 19 deletions src/pybuffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ the python c-api function `PyObject_GetBuffer()`, unless o.obj is a PyPtr(C_NULL
function pydecref(o::PyBuffer)
# note that PyBuffer_Release sets o.obj to NULL, and
# is a no-op if o.obj is already NULL
# TODO change to `Ref{PyBuffer}` when 0.6 is dropped.
_finalized[] || ccall(@pysym(:PyBuffer_Release), Cvoid, (Any,), o)
_finalized[] || ccall(@pysym(:PyBuffer_Release), Cvoid, (Ref{PyBuffer},), o)
o
end

Expand Down Expand Up @@ -96,10 +95,9 @@ end
# Strides in bytes
Base.strides(b::PyBuffer) = ((stride(b,i) for i in 1:b.buf.ndim)...,)

# TODO change to `Ref{PyBuffer}` when 0.6 is dropped.
iscontiguous(b::PyBuffer) =
1 == ccall((@pysym :PyBuffer_IsContiguous), Cint,
(Any, Cchar), b, 'A')
(Ref{PyBuffer}, Cchar), b, 'A')

#############################################################################
# pybuffer constant values from Include/object.h
Expand All @@ -122,35 +120,33 @@ function PyBuffer(o::Union{PyObject,PyPtr}, flags=PyBUF_SIMPLE)
end

function PyBuffer!(b::PyBuffer, o::Union{PyObject,PyPtr}, flags=PyBUF_SIMPLE)
# TODO change to `Ref{PyBuffer}` when 0.6 is dropped.
pydecref(b) # ensure b is properly released
@pycheckz ccall((@pysym :PyObject_GetBuffer), Cint,
(PyPtr, Any, Cint), o, b, flags)
(PyPtr, Ref{PyBuffer}, Cint), o, b, flags)
return b
end

"""
`isbuftype(o::Union{PyObject,PyPtr})`
Returns true if the python object `o` supports the buffer protocol as a strided
array. False if not.
"""
function isbuftype(o::Union{PyObject,PyPtr})
# like isbuftype, but modifies caller's PyBuffer
function isbuftype!(o::Union{PyObject,PyPtr}, b::PyBuffer)
# PyObject_CheckBuffer is defined in a header file here: https://github.com/python/cpython/blob/ef5ce884a41c8553a7eff66ebace908c1dcc1f89/Include/abstract.h#L510
# so we can't access it easily. It basically just checks if PyObject_GetBuffer exists
# So we'll just try call PyObject_GetBuffer and check for success/failure
b = PyBuffer()
ret = ccall((@pysym :PyObject_GetBuffer), Cint,
(PyPtr, Any, Cint), o, b, PyBUF_ND_STRIDED)
if ret != 0
pyerr_clear()
else
# handle pointer types
T, native_byteorder = array_format(b)
T <: Ptr && (ret = 1)
end
return ret == 0
end

"""
isbuftype(o::Union{PyObject,PyPtr})
Returns `true` if the python object `o` supports the buffer protocol as a strided
array. `false` if not.
"""
isbuftype(o::Union{PyObject,PyPtr}) = isbuftype!(o, PyBuffer())

#############################################################################

# recursive function to write buffer dimension by dimension, starting at
Expand Down Expand Up @@ -195,7 +191,8 @@ end
# ref: https://github.com/numpy/numpy/blob/v1.14.2/numpy/core/src/multiarray/buffer.c#L966

const standard_typestrs = Dict{String,DataType}(
"?"=>Bool, "P"=>Ptr{Cvoid},
"?"=>Bool,
"P"=>Ptr{Cvoid}, "O"=>PyPtr,
"b"=>Int8, "B"=>UInt8,
"h"=>Int16, "H"=>UInt16,
"i"=>Int32, "I"=>UInt32,
Expand All @@ -208,7 +205,8 @@ const standard_typestrs = Dict{String,DataType}(
"Zf"=>ComplexF32, "Zd"=>ComplexF64)

const native_typestrs = Dict{String,DataType}(
"?"=>Bool, "P"=>Ptr{Cvoid},
"?"=>Bool,
"P"=>Ptr{Cvoid}, "O"=>PyPtr,
"b"=>Int8, "B"=>UInt8,
"h"=>Cshort, "H"=>Cushort,
"i"=>Cint, "I"=>Cuint,
Expand Down
12 changes: 4 additions & 8 deletions src/pytype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,7 @@ function PyTypeObject!(init::Function, t::PyTypeObject, name::AbstractString, ba
if t.tp_new == C_NULL
t.tp_new = @pyglobal :PyType_GenericNew
end
# TODO change to `Ref{PyTypeObject}` when 0.6 is dropped.
@pycheckz ccall((@pysym :PyType_Ready), Cint, (Any,), t)
@pycheckz ccall((@pysym :PyType_Ready), Cint, (Ref{PyTypeObject},), t)
ccall((@pysym :Py_IncRef), Cvoid, (Any,), t)
return t
end
Expand Down Expand Up @@ -414,8 +413,7 @@ const Py_TPFLAGS_HAVE_STACKLESS_EXTENSION = Ref(0x00000000)
function pyjlwrap_type!(init::Function, to::PyTypeObject, name::AbstractString)
sz = sizeof(Py_jlWrap) + sizeof(PyPtr) # must be > base type
PyTypeObject!(to, name, sz) do t::PyTypeObject
# TODO change to `Ref{PyTypeObject}` when 0.6 is dropped.
t.tp_base = ccall(:jl_value_ptr, Ptr{Cvoid}, (Any,), jlWrapType)
t.tp_base = ccall(:jl_value_ptr, Ptr{Cvoid}, (Ref{PyTypeObject},), jlWrapType)
ccall((@pysym :Py_IncRef), Cvoid, (Any,), jlWrapType)
init(t)
end
Expand All @@ -426,9 +424,8 @@ pyjlwrap_type(init::Function, name::AbstractString) =

# Given a jlwrap type, create a new instance (and save value for gc)
function pyjlwrap_new(pyT::PyTypeObject, value::Any)
# TODO change to `Ref{PyTypeObject}` when 0.6 is dropped.
o = PyObject(@pycheckn ccall((@pysym :_PyObject_New),
PyPtr, (Any,), pyT))
PyPtr, (Ref{PyTypeObject},), pyT))
p = convert(Ptr{Ptr{Cvoid}}, PyPtr(o))
if isimmutable(value)
# It is undefined to call `pointer_from_objref` on immutable objects.
Expand All @@ -452,8 +449,7 @@ function pyjlwrap_new(x::Any)
pyjlwrap_new(jlWrapType, x)
end

# TODO change to `Ref{PyTypeObject}` when 0.6 is dropped.
is_pyjlwrap(o::PyObject) = jlWrapType.tp_new != C_NULL && ccall((@pysym :PyObject_IsInstance), Cint, (PyPtr, Any), o, jlWrapType) == 1
is_pyjlwrap(o::PyObject) = jlWrapType.tp_new != C_NULL && ccall((@pysym :PyObject_IsInstance), Cint, (PyPtr, Ref{PyTypeObject}), o, jlWrapType) == 1

################################################################
# Fallback conversion: if we don't have a better conversion function,
Expand Down

0 comments on commit 33f07eb

Please sign in to comment.