diff --git a/symengine/cwrapper.cpp b/symengine/cwrapper.cpp index b144b6d299..8e9261faaa 100644 --- a/symengine/cwrapper.cpp +++ b/symengine/cwrapper.cpp @@ -30,6 +30,7 @@ using SymEngine::DenseMatrix; using SymEngine::down_cast; using SymEngine::function_symbol; using SymEngine::FunctionSymbol; +using SymEngine::has_symbol; using SymEngine::Integer; using SymEngine::integer_class; using SymEngine::LambdaRealDoubleVisitor; @@ -280,6 +281,11 @@ int number_is_complex(const basic s) return (int)((down_cast(*(s->m))).is_complex()); } +int basic_has_symbol(const basic e, const basic s) +{ + return (int)(has_symbol(*(e->m), *(s->m))); +} + CWRAPPER_OUTPUT_TYPE integer_set_si(basic s, long i) { CWRAPPER_BEGIN diff --git a/symengine/cwrapper.h b/symengine/cwrapper.h index 90aad38758..3ebb54eb23 100644 --- a/symengine/cwrapper.h +++ b/symengine/cwrapper.h @@ -168,6 +168,9 @@ int number_is_positive(const basic s); //! Returns 1 if s is complex; 0 otherwise int number_is_complex(const basic s); +//! Returns 1 if `e` contains the symbol `s`; 0 otherwise +int basic_has_symbol(const basic e, const basic s); + //! Assign to s, a long. CWRAPPER_OUTPUT_TYPE integer_set_si(basic s, long i); //! Assign to s, a ulong. diff --git a/symengine/tests/cwrapper/test_cwrapper.c b/symengine/tests/cwrapper/test_cwrapper.c index 7653c18217..f16e4a3996 100644 --- a/symengine/tests/cwrapper/test_cwrapper.c +++ b/symengine/tests/cwrapper/test_cwrapper.c @@ -55,9 +55,21 @@ void test_cwrapper() basic_str_free(s); integer_set_ui(e, 456); + SYMENGINE_C_ASSERT(basic_has_symbol(e, x) == 0); + SYMENGINE_C_ASSERT(basic_has_symbol(e, y) == 0); + SYMENGINE_C_ASSERT(basic_has_symbol(e, z) == 0); basic_add(e, e, x); + SYMENGINE_C_ASSERT(basic_has_symbol(e, x) == 1); + SYMENGINE_C_ASSERT(basic_has_symbol(e, y) == 0); + SYMENGINE_C_ASSERT(basic_has_symbol(e, z) == 0); basic_mul(e, e, y); + SYMENGINE_C_ASSERT(basic_has_symbol(e, x) == 1); + SYMENGINE_C_ASSERT(basic_has_symbol(e, y) == 1); + SYMENGINE_C_ASSERT(basic_has_symbol(e, z) == 0); basic_div(e, e, z); + SYMENGINE_C_ASSERT(basic_has_symbol(e, x) == 1); + SYMENGINE_C_ASSERT(basic_has_symbol(e, y) == 1); + SYMENGINE_C_ASSERT(basic_has_symbol(e, z) == 1); s = basic_str(e); SYMENGINE_C_ASSERT(strcmp(s, "y*(456 + x)/z") == 0); basic_str_free(s);