Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Complex Arguments to Exp #1332

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 52 additions & 0 deletions symengine/pow.cpp
Expand Up @@ -87,6 +87,25 @@ int Pow::compare(const Basic &o) const
return base_cmp;
}

bool exp_mul_helper(const RCP<const Basic> &b)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this function returning -1, 0, 1,2,3 where -1 is when it is not what we want and the others n in the function beloe

{
if (is_a_Complex(*down_cast<const Mul &>(*b).get_coef())) {
const Mul &s = down_cast<const Mul &>(*b);
const map_basic_basic &dict = s.get_dict();
RCP<const Number> coef
= down_cast<const ComplexBase &>(*s.get_coef()).imaginary_part();
RCP<const Basic> arg = mul(coef, integer(2));
if (dict.size() == 1 and is_a<Integer>(*arg)) {
for (const auto &p : dict) {
if (eq(*p.first, *pi) and eq(*p.second, *one)) {
return true;
}
}
}
}
return false;
}

RCP<const Basic> pow(const RCP<const Basic> &a, const RCP<const Basic> &b)
{
if (is_a_Number(*b) and down_cast<const Number &>(*b).is_zero()) {
Expand Down Expand Up @@ -164,6 +183,39 @@ RCP<const Basic> pow(const RCP<const Basic> &a, const RCP<const Basic> &b)
RCP<const Pow> A = rcp_static_cast<const Pow>(a);
return pow(A->get_base(), mul(A->get_exp(), b));
}
if (eq(*a, *E)) {
if (is_a<Mul>(*b) and exp_mul_helper(b)) {
RCP<const Number> coef = down_cast<const ComplexBase &>(
*down_cast<const Mul &>(*b).get_coef())
.imaginary_part();
RCP<const Basic> arg = mul(coef, integer(2));
long n = ((down_cast<const Integer &>(*arg).as_int() % 4) + 4) % 4;
if (!n) {
return one;
} else if (n == 1) {
return I;
} else if (n == 2) {
return minus_one;
} else {
return Complex::from_two_nums(*zero, *minus_one);
}
} else if (is_a<Add>(*b)) {
const umap_basic_num &dict = down_cast<const Add &>(*b).get_dict();
umap_basic_num new_dict;
RCP<const Number> coef = down_cast<const Add &>(*b).get_coef();
RCP<const Basic> s = one;
for (const auto &p : dict) {
if (eq(*p.first, *pi)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To do this, you don't need to iterate the dictionary. Just use dict.find to find the value pi.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way I can remove that particular entry of pi as well? That way, there'd be no need of iteration at all.

and exp_mul_helper(mul(p.first, p.second))) {
s = exp(mul(p.first, p.second));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can a call to exp be avoided here?

Copy link
Member Author

@ShikharJ ShikharJ Oct 22, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the expense of code duplication it can be. I can't seem to think of another way.

} else {
Add::dict_add_term(new_dict, p.second, p.first);
}
}
return mul(s, make_rcp<const Pow>(
a, Add::from_dict(coef, std::move(new_dict))));
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to carefully benchmark these two if branches, before and after this PR. So the first if clause would run for something like E^(a*b). The second one for E^(a+b). Correct?

If so, then we need to benchmark these two things and similar things a lot.

return make_rcp<const Pow>(a, b);
}

Expand Down
38 changes: 38 additions & 0 deletions symengine/tests/basic/test_arit.cpp
Expand Up @@ -32,6 +32,7 @@ using SymEngine::sub;
using SymEngine::exp;
using SymEngine::E;
using SymEngine::Rational;
using SymEngine::rational;
using SymEngine::Complex;
using SymEngine::Number;
using SymEngine::I;
Expand Down Expand Up @@ -987,6 +988,43 @@ TEST_CASE("Pow: arit", "[arit]")
r1 = pow(mul(sqrt(mul(y, x)), x), i2);
r2 = mul(pow(x, i3), y);
REQUIRE(eq(*r1, *r2));

r1 = exp(mul(I, x));
r2 = pow(E, mul(I, x));
REQUIRE(eq(*r1, *r2));

r1 = exp(mul(I, pi));
r2 = pow(E, mul(I, pi));
REQUIRE(eq(*r1, *r2));

r1 = exp(mul(I, pi));
REQUIRE(eq(*r1, *minus_one));

r1 = exp(mul(mul(I, minus_one), pi));
std::cout << *r1 << std::endl;
REQUIRE(eq(*r1, *minus_one));

r1 = exp(div(mul(I, pi), integer(2)));
REQUIRE(eq(*r1, *I));

r1 = exp(mul(div(mul(I, pi), integer(2)), minus_one));
REQUIRE(eq(*r1, *mul(I, minus_one)));

r1 = div(exp(mul(mul(I, pi), x)), exp(x));
REQUIRE(is_a<Pow>(*r1));
REQUIRE(eq(*down_cast<const Pow &>(*r1).get_base(), *E));
REQUIRE(eq(*down_cast<const Pow &>(*r1).get_exp(),
*add(mul(minus_one, x), mul(x, mul(I, pi)))));

r1 = exp(add(div(mul(I, pi), integer(2)), x));
r2 = mul(I, exp(x));
REQUIRE(eq(*r1, *r2));

r1 = exp(add(div(mul(I, pi), integer(10)), x));
REQUIRE(is_a<Pow>(*r1));
REQUIRE(eq(*down_cast<const Pow &>(*r1).get_base(), *E));
REQUIRE(eq(*down_cast<const Pow &>(*r1).get_exp(),
*add(x, mul(rational(1, 10), mul(I, pi)))));
}

TEST_CASE("Log: arit", "[arit]")
Expand Down