Skip to content

Commit

Permalink
support casting from strings whenever possible
Browse files Browse the repository at this point in the history
  • Loading branch information
kecnry committed Aug 16, 2019
1 parent 05c9825 commit 2d6756f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
66 changes: 64 additions & 2 deletions nparray/nparray.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from collections import OrderedDict
import json
import sys

try:
from astropy import units
Expand All @@ -9,6 +10,9 @@
else:
_has_astropy = True

if sys.version_info[0] == 3:
unicode = str

################## VALIDATORS ###################

# these all must accept a single value and return a boolean if it matches the condition as well as any alterations to the value
Expand Down Expand Up @@ -42,23 +46,51 @@ def is_unit_or_unitstring_or_none(value):

def is_bool(value):
"""must be boolean"""
if isinstance(value, str) or isinstance(value, unicode):
if value.upper() == 'TRUE':
return True, True
elif value.upper() == 'FALSE':
return True, False
else:
return False, None

return isinstance(value, bool), value

def is_float(value):
"""must be a float"""
if isinstance(value, str) or isinstance(value, unicode):
try:
value = float(value)
except:
return False, None
else:
return True, value
return isinstance(value, float) or isinstance(value, int) or isinstance(value, np.float64), float(value)

def is_int(value):
"""must be an integer"""
if isinstance(value, str) or isinstance(value, unicode):
if "." in value:
return False, None

try:
value = int(float(value))
except:
return False, None
else:
return True, value
return isinstance(value, int), value

def is_int_positive(value):
"""must be a positive integer"""
return isinstance(value, int) and value > 0, value
_is_int, value = is_int(value)
return _is_int and value > 0, value

def is_int_positive_or_none(value):
"""must be a postive integer or None"""
return is_int_positive or value is None, value
if value is None:
return True, value
return is_int_positive(value)

def is_valid_shape(value):
"""must be a positive integer or a tuple/list of positive integers"""
Expand Down Expand Up @@ -381,6 +413,9 @@ def __contains__(self, other):
return self.__comparison__('__contains__', other)

class Array(ArrayWrapper):
"""
This is available as a top-level convenience function as <nparray.array>.
"""
def __init__(self, value, unit=None):
"""
This is available as a top-level convenience function as <nparray.array>.
Expand Down Expand Up @@ -424,6 +459,9 @@ def __setitem__(self, index, value):
self.value.__setitem__(index, value)

class Arange(ArrayWrapper):
"""
This is available as a top-level convenience function as <nparray.arange>.
"""
def __init__(self, start, stop, step, unit=None):
"""
This is available as a top-level convenience function as <nparray.arange>.
Expand Down Expand Up @@ -506,6 +544,9 @@ def __math__(self, operator, other):
raise ValueError("{} not supported with type {}".format(operator, type(other)))

class Linspace(ArrayWrapper):
"""
This is available as a top-level convenience function as <nparray.linspace>.
"""
def __init__(self, start, stop, num, endpoint=True, unit=None):
"""
This is available as a top-level convenience function as <nparray.linspace>.
Expand Down Expand Up @@ -582,6 +623,9 @@ def __math__(self, operator, other):
raise ValueError("{} not supported with type {}".format(operator, type(other)))

class Logspace(ArrayWrapper):
"""
This is available as a top-level convenience function as <nparray.logspace>.
"""
def __init__(self, start, stop, num, endpoint=True, base=10.0, unit=None):
"""
This is available as a top-level convenience function as <nparray.logspace>.
Expand Down Expand Up @@ -642,6 +686,9 @@ def _math__(self, operator, other):


class Geomspace(ArrayWrapper):
"""
This is available as a top-level convenience function as <nparray.geomspace>.
"""
def __init__(self, start, stop, num, endpoint=True, unit=None):
"""
This is available as a top-level convenience function as <nparray.geomspace>.
Expand Down Expand Up @@ -702,6 +749,10 @@ def __math__(self, operator, other):
raise ValueError("{} not supported with type {}".format(operator, type(other)))

class Full(ArrayWrapper):
"""
This is available as a top-level convenience function as <nparray.full> or
<nparray.full_like>.
"""
def __init__(self, shape, fill_value, unit=None):
"""
This is available as a top-level convenience function as <nparray.full>
Expand Down Expand Up @@ -789,6 +840,10 @@ def __math__(self, operator, other):


class Zeros(ArrayWrapper):
"""
This is available as a top-level convenience function as <nparray.zeros>
or <nparray.zeros_like>.
"""
def __init__(self, shape, unit=None):
"""
This is available as a top-level convenience function as <nparray.zeros>
Expand Down Expand Up @@ -882,6 +937,10 @@ def __math__(self, operator, other):
raise ValueError("{} not supported with type {}".format(operator, type(other)))

class Ones(ArrayWrapper):
"""
This is available as a top-level convenience function as <nparray.ones> or
<nparray.ones_like>.
"""
def __init__(self, shape, unit=None):
"""
This is available as a top-level convenience function as <nparray.ones>
Expand Down Expand Up @@ -975,6 +1034,9 @@ def __math__(self, operator, other):
raise ValueError("{} not supported with type {}".format(operator, type(other)))

class Eye(ArrayWrapper):
"""
This is available as a top-level convenience function as <nparray.eye>.
"""
def __init__(self, M, N=None, k=0, unit=None):
"""
This is available as a top-level convenience function as <nparray.eye>.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_arange.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_create():

def test_create_errors():
# passing string
assert_raises(ValueError, npa.arange, 0, "1", 1)
assert_raises(ValueError, npa.arange, 0, "a", 1)

# not enough args
assert_raises(TypeError, npa.arange, 0, 1)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_create():

def test_create_errors():
# passing string
assert_raises(ValueError, npa.linspace, 0, "1", 11)
assert_raises(ValueError, npa.linspace, 0, "a", 11)

# not enough args
assert_raises(TypeError, npa.linspace, 0, 1)
Expand Down

0 comments on commit 2d6756f

Please sign in to comment.