Skip to content

Commit

Permalink
support more magic methods
Browse files Browse the repository at this point in the history
related to #1
  • Loading branch information
kecnry committed Jun 17, 2017
1 parent 08b946e commit 2107f1d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 7 deletions.
10 changes: 10 additions & 0 deletions examples/comparisons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import nparray

a = nparray.linspace(0,1,10,False)
b = nparray.arange(0,1,0.1)

print a, a.array
print b, b.array
print a.array==b.array, a==b
print a==a.array
print 0 in a.array, 0 in a
63 changes: 56 additions & 7 deletions nparray/nparray.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,18 @@ def __getattr__(self, name):
# print "*** __getattr__", name
if name in ['_descriptors', '_validators']:
# then we need to actually get the attribute
return super(Array, self).__getattr__(name)
return super(ArrayWrapper, self).__getattr__(name)
# return self._descriptors
elif name in self._descriptors.keys():
# then get the item in the dictionary
return self._descriptors.get(name)
else:
# elif hasattr(self, name):
# return super(ArrayWrapper, self).__getattr__(name)
elif hasattr(self.array, name):
# then fallback on the underlying array object
return getattr(self.array, name)
else:
raise AttributeError("neither '{}' or '{}' have attribute '{}'".format(self.__class__.__name__.lower(), 'numpy.ndarray', name))

def __setattr__(self, name, value):
"""
Expand Down Expand Up @@ -154,11 +158,17 @@ def __repr__(self):
descriptors = " ".join(["{}={}".format(k,v) for k,v in self._descriptors.items()])
return "<{} {}>".format(self.__class__.__name__.lower(), descriptors)

def __eq__(self, other):
"""
determine eq based on contents of underlying array
"""
return self.array.__eq__(other.array) if isinstance(other, ArrayWrapper) else self.array.__eq__(other)
def __str__(self):
return self.array.__str__()

def __copy__(self):
return self.__class__(**self._descriptors)

def __deepcopy__(self):
return self.__copy__(self)

def copy(self):
return self.__copy__()

def to_array(self):
return Array(self.array)
Expand Down Expand Up @@ -201,6 +211,45 @@ def __div__(self, other):
def __rdiv__(self, other):
return self.__math__('__rdiv__', other)

def __pow__(self, other):
return self.__math__('__pow__', other)

def __rpow__(self, other):
return Array(self.array.__rpow__(other))

def __abs__(self):
return Array(abs(self.array))

def __len__(self):
return len(self.array)

def __comparison__(self, operator, other):
"""
determine comparisons based on the underyling arrays
"""
return getattr(self.array, operator)(other)

def __eq__(self, other):
return self.__comparison__('__eq__', other)

def __ne__(self, other):
return self.__comparison__('__ne__', other)

def __lt__(self, other):
return self.__comparison__('__lt__', other)

def __lte__(self, other):
return self.__comparison__('__le__', other)

def __gt__(self, other):
return self.__comparison__('__gt__', other)

def __ge__(self, other):
return self.__comparison__('__ge__', other)

def __contains__(self, other):
return self.__comparison__('__contains__', other)

class Array(ArrayWrapper):
def __init__(self, value):
super(Array, self).__init__(('value', value, is_iterable))
Expand Down

0 comments on commit 2107f1d

Please sign in to comment.