Skip to content

Commit

Permalink
Add complex str constructor (#502)
Browse files Browse the repository at this point in the history
  • Loading branch information
arshajii committed Dec 5, 2023
1 parent b4a3f89 commit d1cd21b
Show file tree
Hide file tree
Showing 3 changed files with 330 additions and 2 deletions.
91 changes: 91 additions & 0 deletions stdlib/internal/builtin.codon
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,97 @@ class float:

return result

@extend
class complex:
def _from_str(v: str) -> complex:
def parse_error():
raise ValueError("complex() arg is a malformed string")

buf = __array__[byte](32)
n = len(v)
need_dyn_alloc = n >= len(buf)

s = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr
str.memcpy(s, v.ptr, n)
s[n] = byte(0)

x = 0.0
y = 0.0
z = 0.0
got_bracket = False
start = s
end = cobj()

while str._isspace(s[0]):
s += 1

if s[0] == byte(40): # '('
got_bracket = True
s += 1
while str._isspace(s[0]):
s += 1

z = _C.strtod(s, __ptr__(end))

if end != s:
s = end

if s[0] == byte(43) or s[0] == byte(45): # '+' '-'
x = z
y = _C.strtod(s, __ptr__(end))

if end != s:
s = end
else:
y = 1.0 if s[0] == byte(43) else -1.0
s += 1

if not (s[0] == byte(106) or s[0] == byte(74)): # 'j' 'J'
if need_dyn_alloc:
free(s)
parse_error()

s += 1
elif s[0] == byte(106) or s[0] == byte(74): # 'j' 'J'
s += 1
y = z
else:
x = z
else:
if s[0] == byte(43) or s[0] == byte(45): # '+' '-'
y = 1.0 if s[0] == byte(43) else -1.0
s += 1
else:
y = 1.0

if not (s[0] == byte(106) or s[0] == byte(74)): # 'j' 'J'
if need_dyn_alloc:
free(s)
parse_error()

s += 1

while str._isspace(s[0]):
s += 1

if got_bracket:
if s[0] != byte(41): # ')'
if need_dyn_alloc:
free(s)
parse_error()
s += 1
while str._isspace(s[0]):
s += 1

if s - start != n:
if need_dyn_alloc:
free(s)
parse_error()

if need_dyn_alloc:
free(s)
return complex(x, y)

def _jit_display(x, s: Static[str], bundle: Set[str] = Set[str]()):
if isinstance(x, None):
return
Expand Down
8 changes: 6 additions & 2 deletions stdlib/internal/types/complex.codon
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@ class complex:
def __new__() -> complex:
return (0.0, 0.0)

def __new__(other):
return other.__complex__()
def __new__(what):
# do not overload! (needed to avoid pyobj conversion)
if isinstance(what, str) or isinstance(what, Optional[str]):
return complex._from_str(what)
else:
return what.__complex__()

def __new__(real, imag) -> complex:
return (float(real), float(imag))
Expand Down
233 changes: 233 additions & 0 deletions test/stdlib/cmath_test.codon
Original file line number Diff line number Diff line change
Expand Up @@ -865,3 +865,236 @@ def test_complex64():


test_complex64()


def test_complex_from_string():
# for tests when string is not zero-terminated
def f(s):
return complex(s[1:-1])

def g(s):
return complex(' ' * 50 + s[1:-1] + ' ' * 50)

assert complex("1") == 1+0j
assert complex("1j") == 1j
assert complex("-1") == -1
assert complex("+1") == +1
assert complex("(1+2j)") == 1+2j
assert complex("(1.3+2.2j)") == 1.3+2.2j
assert complex("3.14+1J") == 3.14+1j
assert complex(" ( +3.14-6J )") == 3.14-6j
assert complex(" ( +3.14-J )") == 3.14-1j
assert complex(" ( +3.14+j )") == 3.14+1j
assert complex("J") == 1j
assert complex("( j )") == 1j
assert complex("+J") == 1j
assert complex("( -j)") == -1j
assert complex('1e-500') == 0.0 + 0.0j
assert complex('-1e-500j') == 0.0 - 0.0j
assert complex('-1e-500+1e-500j') == -0.0 + 0.0j
assert complex('1-1j') == 1.0 - 1j
assert complex('1J') == 1j

assert f("x1x") == 1+0j
assert f("x1jx") == 1j
assert f("x-1x") == -1
assert f("x+1x") == +1
assert f("x(1+2j)x") == 1+2j
assert f("x(1.3+2.2j)x") == 1.3+2.2j
assert f("x3.14+1Jx") == 3.14+1j
assert f("x ( +3.14-6J )x") == 3.14-6j
assert f("x ( +3.14-J )x") == 3.14-1j
assert f("x ( +3.14+j )x") == 3.14+1j
assert f("xJx") == 1j
assert f("x( j )x") == 1j
assert f("x+Jx") == 1j
assert f("x( -j)x") == -1j
assert f('x1e-500x') == 0.0 + 0.0j
assert f('x-1e-500jx') == 0.0 - 0.0j
assert f('x-1e-500+1e-500jx') == -0.0 + 0.0j
assert f('x1-1jx') == 1.0 - 1j
assert f('x1Jx') == 1j

assert g("x1x") == 1+0j
assert g("x1jx") == 1j
assert g("x-1x") == -1
assert g("x+1x") == +1
assert g("x(1+2j)x") == 1+2j
assert g("x(1.3+2.2j)x") == 1.3+2.2j
assert g("x3.14+1Jx") == 3.14+1j
assert g("x ( +3.14-6J )x") == 3.14-6j
assert g("x ( +3.14-J )x") == 3.14-1j
assert g("x ( +3.14+j )x") == 3.14+1j
assert g("xJx") == 1j
assert g("x( j )x") == 1j
assert g("x+Jx") == 1j
assert g("x( -j)x") == -1j
assert g('x1e-500x') == 0.0 + 0.0j
assert g('x-1e-500jx') == 0.0 - 0.0j
assert g('x-1e-500+1e-500jx') == -0.0 + 0.0j
assert g('x1-1jx') == 1.0 - 1j
assert g('x1Jx') == 1j

try:
complex("\0")
assert False
except ValueError:
pass

try:
complex("3\09")
assert False
except ValueError:
pass

try:
complex("1+")
assert False
except ValueError:
pass

try:
complex("1+1j+1j")
assert False
except ValueError:
pass

try:
complex("--")
assert False
except ValueError:
pass

try:
complex("(1+2j")
assert False
except ValueError:
pass

try:
complex("1+2j)")
assert False
except ValueError:
pass

try:
complex("1+(2j)")
assert False
except ValueError:
pass

try:
complex("(1+2j)123")
assert False
except ValueError:
pass

try:
complex("x")
assert False
except ValueError:
pass

try:
complex("1j+2")
assert False
except ValueError:
pass

try:
complex("1e1ej")
assert False
except ValueError:
pass

try:
complex("1e++1ej")
assert False
except ValueError:
pass

try:
complex(")1+2j(")
assert False
except ValueError:
pass

try:
complex("")
assert False
except ValueError:
pass

try:
f(" 1+2j")
assert False
except ValueError:
pass

try:
f("1..1j")
assert False
except ValueError:
pass

try:
f("1.11.1j")
assert False
except ValueError:
pass

try:
f("1e1.1j")
assert False
except ValueError:
pass

try:
f(" ")
assert False
except ValueError:
pass

try:
f(" J")
assert False
except ValueError:
pass

try:
g(" 1+2j")
assert False
except ValueError:
pass

try:
g("1..1j")
assert False
except ValueError:
pass

try:
g("1.11.1j")
assert False
except ValueError:
pass

try:
g("1e1.1j")
assert False
except ValueError:
pass

try:
g(" ")
assert False
except ValueError:
pass

try:
g(" J")
assert False
except ValueError:
pass

test_complex_from_string()

0 comments on commit d1cd21b

Please sign in to comment.