From d1cd21b3e3d8798d2e08f2770b83a99271e9337d Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Tue, 5 Dec 2023 15:48:34 -0500 Subject: [PATCH] Add complex str constructor (#502) --- stdlib/internal/builtin.codon | 91 +++++++++++ stdlib/internal/types/complex.codon | 8 +- test/stdlib/cmath_test.codon | 233 ++++++++++++++++++++++++++++ 3 files changed, 330 insertions(+), 2 deletions(-) diff --git a/stdlib/internal/builtin.codon b/stdlib/internal/builtin.codon index b5c21323..f37e93d4 100644 --- a/stdlib/internal/builtin.codon +++ b/stdlib/internal/builtin.codon @@ -402,6 +402,97 @@ class float: return result +@extend +class complex: + def _from_str(v: str) -> complex: + def parse_error(): + raise ValueError("complex() arg is a malformed string") + + buf = __array__[byte](32) + n = len(v) + need_dyn_alloc = n >= len(buf) + + s = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr + str.memcpy(s, v.ptr, n) + s[n] = byte(0) + + x = 0.0 + y = 0.0 + z = 0.0 + got_bracket = False + start = s + end = cobj() + + while str._isspace(s[0]): + s += 1 + + if s[0] == byte(40): # '(' + got_bracket = True + s += 1 + while str._isspace(s[0]): + s += 1 + + z = _C.strtod(s, __ptr__(end)) + + if end != s: + s = end + + if s[0] == byte(43) or s[0] == byte(45): # '+' '-' + x = z + y = _C.strtod(s, __ptr__(end)) + + if end != s: + s = end + else: + y = 1.0 if s[0] == byte(43) else -1.0 + s += 1 + + if not (s[0] == byte(106) or s[0] == byte(74)): # 'j' 'J' + if need_dyn_alloc: + free(s) + parse_error() + + s += 1 + elif s[0] == byte(106) or s[0] == byte(74): # 'j' 'J' + s += 1 + y = z + else: + x = z + else: + if s[0] == byte(43) or s[0] == byte(45): # '+' '-' + y = 1.0 if s[0] == byte(43) else -1.0 + s += 1 + else: + y = 1.0 + + if not (s[0] == byte(106) or s[0] == byte(74)): # 'j' 'J' + if need_dyn_alloc: + free(s) + parse_error() + + s += 1 + + while str._isspace(s[0]): + s += 1 + + if got_bracket: + if s[0] != byte(41): # ')' + if need_dyn_alloc: + free(s) + parse_error() + s += 1 + while str._isspace(s[0]): + s += 1 + + if s - start != n: + if need_dyn_alloc: + free(s) + parse_error() + + if need_dyn_alloc: + free(s) + return complex(x, y) + def _jit_display(x, s: Static[str], bundle: Set[str] = Set[str]()): if isinstance(x, None): return diff --git a/stdlib/internal/types/complex.codon b/stdlib/internal/types/complex.codon index d7e85934..207ecb7e 100644 --- a/stdlib/internal/types/complex.codon +++ b/stdlib/internal/types/complex.codon @@ -13,8 +13,12 @@ class complex: def __new__() -> complex: return (0.0, 0.0) - def __new__(other): - return other.__complex__() + def __new__(what): + # do not overload! (needed to avoid pyobj conversion) + if isinstance(what, str) or isinstance(what, Optional[str]): + return complex._from_str(what) + else: + return what.__complex__() def __new__(real, imag) -> complex: return (float(real), float(imag)) diff --git a/test/stdlib/cmath_test.codon b/test/stdlib/cmath_test.codon index 834a6d73..d2e8926e 100644 --- a/test/stdlib/cmath_test.codon +++ b/test/stdlib/cmath_test.codon @@ -865,3 +865,236 @@ def test_complex64(): test_complex64() + + +def test_complex_from_string(): + # for tests when string is not zero-terminated + def f(s): + return complex(s[1:-1]) + + def g(s): + return complex(' ' * 50 + s[1:-1] + ' ' * 50) + + assert complex("1") == 1+0j + assert complex("1j") == 1j + assert complex("-1") == -1 + assert complex("+1") == +1 + assert complex("(1+2j)") == 1+2j + assert complex("(1.3+2.2j)") == 1.3+2.2j + assert complex("3.14+1J") == 3.14+1j + assert complex(" ( +3.14-6J )") == 3.14-6j + assert complex(" ( +3.14-J )") == 3.14-1j + assert complex(" ( +3.14+j )") == 3.14+1j + assert complex("J") == 1j + assert complex("( j )") == 1j + assert complex("+J") == 1j + assert complex("( -j)") == -1j + assert complex('1e-500') == 0.0 + 0.0j + assert complex('-1e-500j') == 0.0 - 0.0j + assert complex('-1e-500+1e-500j') == -0.0 + 0.0j + assert complex('1-1j') == 1.0 - 1j + assert complex('1J') == 1j + + assert f("x1x") == 1+0j + assert f("x1jx") == 1j + assert f("x-1x") == -1 + assert f("x+1x") == +1 + assert f("x(1+2j)x") == 1+2j + assert f("x(1.3+2.2j)x") == 1.3+2.2j + assert f("x3.14+1Jx") == 3.14+1j + assert f("x ( +3.14-6J )x") == 3.14-6j + assert f("x ( +3.14-J )x") == 3.14-1j + assert f("x ( +3.14+j )x") == 3.14+1j + assert f("xJx") == 1j + assert f("x( j )x") == 1j + assert f("x+Jx") == 1j + assert f("x( -j)x") == -1j + assert f('x1e-500x') == 0.0 + 0.0j + assert f('x-1e-500jx') == 0.0 - 0.0j + assert f('x-1e-500+1e-500jx') == -0.0 + 0.0j + assert f('x1-1jx') == 1.0 - 1j + assert f('x1Jx') == 1j + + assert g("x1x") == 1+0j + assert g("x1jx") == 1j + assert g("x-1x") == -1 + assert g("x+1x") == +1 + assert g("x(1+2j)x") == 1+2j + assert g("x(1.3+2.2j)x") == 1.3+2.2j + assert g("x3.14+1Jx") == 3.14+1j + assert g("x ( +3.14-6J )x") == 3.14-6j + assert g("x ( +3.14-J )x") == 3.14-1j + assert g("x ( +3.14+j )x") == 3.14+1j + assert g("xJx") == 1j + assert g("x( j )x") == 1j + assert g("x+Jx") == 1j + assert g("x( -j)x") == -1j + assert g('x1e-500x') == 0.0 + 0.0j + assert g('x-1e-500jx') == 0.0 - 0.0j + assert g('x-1e-500+1e-500jx') == -0.0 + 0.0j + assert g('x1-1jx') == 1.0 - 1j + assert g('x1Jx') == 1j + + try: + complex("\0") + assert False + except ValueError: + pass + + try: + complex("3\09") + assert False + except ValueError: + pass + + try: + complex("1+") + assert False + except ValueError: + pass + + try: + complex("1+1j+1j") + assert False + except ValueError: + pass + + try: + complex("--") + assert False + except ValueError: + pass + + try: + complex("(1+2j") + assert False + except ValueError: + pass + + try: + complex("1+2j)") + assert False + except ValueError: + pass + + try: + complex("1+(2j)") + assert False + except ValueError: + pass + + try: + complex("(1+2j)123") + assert False + except ValueError: + pass + + try: + complex("x") + assert False + except ValueError: + pass + + try: + complex("1j+2") + assert False + except ValueError: + pass + + try: + complex("1e1ej") + assert False + except ValueError: + pass + + try: + complex("1e++1ej") + assert False + except ValueError: + pass + + try: + complex(")1+2j(") + assert False + except ValueError: + pass + + try: + complex("") + assert False + except ValueError: + pass + + try: + f(" 1+2j") + assert False + except ValueError: + pass + + try: + f("1..1j") + assert False + except ValueError: + pass + + try: + f("1.11.1j") + assert False + except ValueError: + pass + + try: + f("1e1.1j") + assert False + except ValueError: + pass + + try: + f(" ") + assert False + except ValueError: + pass + + try: + f(" J") + assert False + except ValueError: + pass + + try: + g(" 1+2j") + assert False + except ValueError: + pass + + try: + g("1..1j") + assert False + except ValueError: + pass + + try: + g("1.11.1j") + assert False + except ValueError: + pass + + try: + g("1e1.1j") + assert False + except ValueError: + pass + + try: + g(" ") + assert False + except ValueError: + pass + + try: + g(" J") + assert False + except ValueError: + pass + +test_complex_from_string()