Skip to content

Commit

Permalink
Implements partial substitution of Mul objects
Browse files Browse the repository at this point in the history
  • Loading branch information
eeshan9815 committed Feb 7, 2018
1 parent aa90434 commit d47f40b
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 0 deletions.
223 changes: 223 additions & 0 deletions symengine/subs.h
Expand Up @@ -276,6 +276,229 @@ class SubsVisitor : public BaseVisitor<SubsVisitor, XReplaceVisitor>
: BaseVisitor<SubsVisitor, XReplaceVisitor>(subs_dict_)
{
}
void bvisit(const Mul &x)
{
RCP<const Number> coef = x.get_coef();
map_basic_basic dict = x.get_dict();
map_basic_basic d;
bool fast_exec = false;
for (const auto &p : x.get_dict()) {
RCP<const Basic> factor_old;
if (eq(*p.second, *one)) {
factor_old = p.first;
} else {
factor_old = make_rcp<Pow>(p.first, p.second);
}
RCP<const Basic> factor = apply(factor_old);
if (factor == factor_old) {
Mul::dict_add_term_new(outArg(coef), d, p.second, p.first);
} else if (is_a_Number(*factor)) {
fast_exec = true;
if (down_cast<const Number &>(*factor).is_zero()) {
result_ = factor;
return;
}
imulnum(outArg(coef), rcp_static_cast<const Number>(factor));
} else if (is_a<Mul>(*factor)) {
fast_exec = true;
RCP<const Mul> tmp = rcp_static_cast<const Mul>(factor);
imulnum(outArg(coef), tmp->get_coef());
for (const auto &q : tmp->get_dict()) {
Mul::dict_add_term_new(outArg(coef), d, q.second, q.first);
}
} else {
fast_exec = true;
RCP<const Basic> exp, t;
Mul::as_base_exp(factor, outArg(exp), outArg(t));
Mul::dict_add_term_new(outArg(coef), d, exp, t);
}
}
if (fast_exec) {
result_ = Mul::from_dict(coef, std::move(d));
return;
}
for (const auto &iter : subs_dict_) {
d.clear();
bool exists = true;
auto sub1 = iter.first;
auto rep = iter.second;
if (is_a<Mul>(*sub1)) {
RCP<const Mul> subst = rcp_static_cast<const Mul>(sub1);
for (auto &p : subst->get_dict()) {
auto it = dict.find(p.first);
RCP<const Basic> diff_;
if (it != dict.end())
diff_ = sub(it->second, p.second);
if (it == dict.end()
|| down_cast<const Number &>(*diff_).is_negative()) {
exists = false;
break;
} else {
if (!down_cast<const Number &>(*diff_).is_zero())
Mul::dict_add_term_new(outArg(coef), d,
sub(it->second, p.second),
p.first);
}
}
if (exists) {
for (const auto &p : dict) {
auto it = subst->get_dict().find(p.first);
if (it == subst->get_dict().end())
Mul::dict_add_term_new(outArg(coef), d, p.second,
p.first);
}
if (is_a_Number(*rep)) {
if (down_cast<const Number &>(*rep).is_zero()) {
result_ = rep;
return;
}
imulnum(outArg(coef),
rcp_static_cast<const Number>(rep));
} else if (is_a<Mul>(*rep)) {
RCP<const Mul> tmp = rcp_static_cast<const Mul>(rep);
imulnum(outArg(coef), tmp->get_coef());
for (const auto &q : tmp->get_dict()) {
Mul::dict_add_term_new(outArg(coef), d, q.second,
q.first);
}
} else {
RCP<const Basic> exp, t;
Mul::as_base_exp(rep, outArg(exp), outArg(t));
Mul::dict_add_term_new(outArg(coef), d, exp, t);
}
} else
d = x.get_dict();
} else if (is_a<Pow>(*sub1)) {
RCP<const Pow> subst = rcp_static_cast<const Pow>(sub1);
auto sub1_exp = subst->get_exp();
auto sub1_base = subst->get_base();
exists = false;
if (is_a_Number(*sub1_exp)) {
for (const auto &p : dict) {
auto diff_ = sub(p.second, sub1_exp);
if (eq(*sub1_base, *(p.first))
and eq(*sub1_exp, *p.second)) {
exists = true;
} else if (eq(*sub1_base, *(p.first))
and down_cast<const Number &>(*diff_)
.is_positive()) {
exists = true;
Mul::dict_add_term_new(outArg(coef), d,
sub(p.second, sub1_exp),
p.first);
} else {
Mul::dict_add_term_new(outArg(coef), d, p.second,
p.first);
}
}
} else {
for (const auto &p : dict) {
if (eq(*sub1_base, *(p.first))
and eq(*sub1_exp, *p.second)) {
exists = true;
} else {
Mul::dict_add_term_new(outArg(coef), d, p.second,
p.first);
}
}
}
if (exists) {
if (is_a_Number(*rep)) {
if (down_cast<const Number &>(*rep).is_zero()) {
result_ = rep;
return;
}
imulnum(outArg(coef),
rcp_static_cast<const Number>(rep));
} else if (is_a<Mul>(*rep)) {
RCP<const Mul> tmp = rcp_static_cast<const Mul>(rep);
imulnum(outArg(coef), tmp->get_coef());
for (const auto &q : tmp->get_dict()) {
Mul::dict_add_term_new(outArg(coef), d, q.second,
q.first);
}
} else {
RCP<const Basic> exp, t;
Mul::as_base_exp(rep, outArg(exp), outArg(t));
Mul::dict_add_term_new(outArg(coef), d, exp, t);
}
} else
d = x.get_dict();
} else if (is_a<Symbol>(*sub1)) {
exists = false;
for (const auto &p : dict) {
if (eq(*sub1, *(p.first)) and eq(*one, *p.second)) {
exists = true;
} else if (eq(*sub1, *(p.first))
and not eq(*one, *p.second)) {
exists = true;
Mul::dict_add_term_new(outArg(coef), d,
sub(p.second, one), p.first);
} else {
Mul::dict_add_term_new(outArg(coef), d, p.second,
p.first);
}
}
if (exists) {
if (is_a_Number(*rep)) {
if (down_cast<const Number &>(*rep).is_zero()) {
result_ = rep;
return;
}
imulnum(outArg(coef),
rcp_static_cast<const Number>(rep));
} else if (is_a<Mul>(*rep)) {
RCP<const Mul> tmp = rcp_static_cast<const Mul>(rep);
imulnum(outArg(coef), tmp->get_coef());
for (const auto &q : tmp->get_dict()) {
Mul::dict_add_term_new(outArg(coef), d, q.second,
q.first);
}
} else {
RCP<const Basic> exp, t;
Mul::as_base_exp(rep, outArg(exp), outArg(t));
Mul::dict_add_term_new(outArg(coef), d, exp, t);
}
} else
d = x.get_dict();
} else {
exists = false;
for (const auto &p : dict) {
if (eq(*sub1, *(p.first))) {
exists = true;
} else {
Mul::dict_add_term_new(outArg(coef), d, p.second,
p.first);
}
}
if (exists) {
if (is_a_Number(*rep)) {
if (down_cast<const Number &>(*rep).is_zero()) {
result_ = rep;
return;
}
imulnum(outArg(coef),
rcp_static_cast<const Number>(rep));
} else if (is_a<Mul>(*rep)) {
RCP<const Mul> tmp = rcp_static_cast<const Mul>(rep);
imulnum(outArg(coef), tmp->get_coef());
for (const auto &q : tmp->get_dict()) {
Mul::dict_add_term_new(outArg(coef), d, q.second,
q.first);
}
} else {
RCP<const Basic> exp, t;
Mul::as_base_exp(rep, outArg(exp), outArg(t));
Mul::dict_add_term_new(outArg(coef), d, exp, t);
}
} else
d = x.get_dict();
}
dict.clear();
dict.insert(d.begin(), d.end());
}
result_ = Mul::from_dict(coef, std::move(d));
}

void bvisit(const Pow &x)
{
Expand Down
71 changes: 71 additions & 0 deletions symengine/tests/basic/test_subs.cpp
Expand Up @@ -152,6 +152,77 @@ TEST_CASE("Mul: subs", "[subs]")
r2 = z;
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = z;
r1 = mul(mul(x, y), z);
r2 = pow(z, i2);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = mul(y, z);
r1 = mul(mul(x, y), z);
r2 = mul(y, pow(z, i2));
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = i4;
r1 = mul(mul(x, y), z);
r2 = mul(z, i4);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = z;
r1 = add(add(mul(mul(pow(y, i2), x), i2), mul(i3, pow(x, i2))),
mul(i4, pow(y, i2)));
r2 = add(add(mul(i3, pow(x, i2)), mul(i4, pow(y, i2))), mul(mul(i2, y), z));
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(mul(x, y), z)] = i2;
r1 = add(
add(mul(mul(mul(pow(y, i3), x), i2), pow(z, i2)), mul(i3, pow(x, i2))),
mul(i4, pow(y, i2)));
r2 = add(add(mul(i3, pow(x, i2)), mul(i4, pow(y, i2))),
mul(mul(pow(y, i2), z), i4));
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = z;
r1 = mul(x, z);
r2 = mul(x, z);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[pow(x, y)] = z;
r1 = mul(z, pow(x, y));
r2 = pow(z, i2);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[pow(x, i2)] = z;
r1 = mul(z, pow(x, i2));
r2 = pow(z, i2);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(y, pow(z, i2))] = x;
r1 = add(add(mul(mul(x, y), pow(z, i2)), mul(mul(z, y), pow(x, i2))),
mul(y, pow(z, i3)));
r2 = add(add(pow(x, i2), mul(mul(z, y), pow(x, i2))), mul(x, z));
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[x] = i2;
r1 = mul(pow(x, i2), y);
r2 = mul(i4, y);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[x] = i2;
r1 = add(mul(pow(x, i2), y), mul(pow(x, i2), z));
r2 = add(mul(i4, y), mul(i4, z));
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[pow(x, y)] = z;
r1 = mul(i2, pow(x, y));
Expand Down

0 comments on commit d47f40b

Please sign in to comment.