Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements partial substitution of Mul objects #1395

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
232 changes: 232 additions & 0 deletions symengine/subs.h
Expand Up @@ -276,6 +276,238 @@ class SubsVisitor : public BaseVisitor<SubsVisitor, XReplaceVisitor>
: BaseVisitor<SubsVisitor, XReplaceVisitor>(subs_dict_)
{
}
void bvisit(const Mul &x)
{
RCP<const Number> coef = x.get_coef();
map_basic_basic dict = x.get_dict();
map_basic_basic d;
bool fast_exec = false;
for (const auto &p : x.get_dict()) {
RCP<const Basic> factor_old;
if (eq(*p.second, *one)) {
factor_old = p.first;
} else {
factor_old = make_rcp<Pow>(p.first, p.second);
}
RCP<const Basic> factor = apply(factor_old);
if (factor == factor_old) {
Mul::dict_add_term_new(outArg(coef), d, p.second, p.first);
} else if (is_a_Number(*factor)) {
fast_exec = true;
if (down_cast<const Number &>(*factor).is_zero()) {
result_ = factor;
return;
}
imulnum(outArg(coef), rcp_static_cast<const Number>(factor));
} else if (is_a<Mul>(*factor)) {
fast_exec = true;
RCP<const Mul> tmp = rcp_static_cast<const Mul>(factor);
imulnum(outArg(coef), tmp->get_coef());
for (const auto &q : tmp->get_dict()) {
Mul::dict_add_term_new(outArg(coef), d, q.second, q.first);
}
} else {
fast_exec = true;
RCP<const Basic> exp, t;
Mul::as_base_exp(factor, outArg(exp), outArg(t));
Mul::dict_add_term_new(outArg(coef), d, exp, t);
}
}
if (fast_exec and subs_dict_.size() == 1) {
result_ = Mul::from_dict(coef, std::move(d));
return;
}
for (const auto &iter : subs_dict_) {
d.clear();
bool exists = true;
auto sub1 = iter.first;
auto rep = iter.second;
if (is_a<Mul>(*sub1)) {
RCP<const Mul> subst = rcp_static_cast<const Mul>(sub1);
auto sub_coef = subst->get_coef();
if (not eq(*mod(*rcp_static_cast<const Integer>(coef),
*rcp_static_cast<const Integer>(sub_coef)),
*zero))
exists = false;
else {
for (auto &p : subst->get_dict()) {
auto it = dict.find(p.first);
RCP<const Basic> diff_;
if (it != dict.end())
diff_ = sub(it->second, p.second);
if (it == dict.end()
|| down_cast<const Number &>(*diff_)
.is_negative()) {
exists = false;
break;
} else {
if (!down_cast<const Number &>(*diff_).is_zero())
Mul::dict_add_term_new(
outArg(coef), d, sub(it->second, p.second),
p.first);
}
}
}
if (exists) {
idivnum(outArg(coef), sub_coef);
for (const auto &p : dict) {
auto it = subst->get_dict().find(p.first);
if (it == subst->get_dict().end())
Mul::dict_add_term_new(outArg(coef), d, p.second,
p.first);
}
if (is_a_Number(*rep)) {
if (down_cast<const Number &>(*rep).is_zero()) {
result_ = rep;
return;
}
imulnum(outArg(coef),
rcp_static_cast<const Number>(rep));
} else if (is_a<Mul>(*rep)) {
RCP<const Mul> tmp = rcp_static_cast<const Mul>(rep);
imulnum(outArg(coef), tmp->get_coef());
for (const auto &q : tmp->get_dict()) {
Mul::dict_add_term_new(outArg(coef), d, q.second,
q.first);
}
} else {
RCP<const Basic> exp, t;
Mul::as_base_exp(rep, outArg(exp), outArg(t));
Mul::dict_add_term_new(outArg(coef), d, exp, t);
}
} else
d = dict;
} else if (is_a<Pow>(*sub1)) {
RCP<const Pow> subst = rcp_static_cast<const Pow>(sub1);
auto sub1_exp = subst->get_exp();
auto sub1_base = subst->get_base();
exists = false;
if (is_a_Number(*sub1_exp)) {
for (const auto &p : dict) {
auto diff_ = sub(p.second, sub1_exp);
if (eq(*sub1_base, *(p.first))
and eq(*sub1_exp, *p.second)) {
exists = true;
} else if (eq(*sub1_base, *(p.first))
and down_cast<const Number &>(*diff_)
.is_positive()) {
exists = true;
Mul::dict_add_term_new(outArg(coef), d,
sub(p.second, sub1_exp),
p.first);
} else {
Mul::dict_add_term_new(outArg(coef), d, p.second,
p.first);
}
}
} else {
for (const auto &p : dict) {
if (eq(*sub1_base, *(p.first))
and eq(*sub1_exp, *p.second)) {
exists = true;
} else {
Mul::dict_add_term_new(outArg(coef), d, p.second,
p.first);
}
}
}
if (exists) {
if (is_a_Number(*rep)) {
if (down_cast<const Number &>(*rep).is_zero()) {
result_ = rep;
return;
}
imulnum(outArg(coef),
rcp_static_cast<const Number>(rep));
} else if (is_a<Mul>(*rep)) {
RCP<const Mul> tmp = rcp_static_cast<const Mul>(rep);
imulnum(outArg(coef), tmp->get_coef());
for (const auto &q : tmp->get_dict()) {
Mul::dict_add_term_new(outArg(coef), d, q.second,
q.first);
}
} else {
RCP<const Basic> exp, t;
Mul::as_base_exp(rep, outArg(exp), outArg(t));
Mul::dict_add_term_new(outArg(coef), d, exp, t);
}
} else
d = x.get_dict();
} else if (is_a<Symbol>(*sub1)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain this part? Why is sub1 being a Symbol important? I thought the issue was that sub1 being a Mul isn't handled correctly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Partial substitution can happen for Symbol objects as well. For example, if the expression is x**2*y and sub1 is x, there is still partial substitution possible.
The first part only looks for exact matches.

Follow up question, should substitution happen repeatedly until not possible? For instance, if the expression is x**2*y**3 and x*y is to be replaced by z, should the result be x*y**2*z or y*z**2?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, (x**2*y).subs({x: z}) didn't work before?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It did, but I've made the non fast_exec implementation complete. Should I remove these redundant parts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ping @isuruf

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's redundant, remove them

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eeshan9815, sorry about the delay. Can you remove redundant code here?

exists = false;
for (const auto &p : dict) {
if (eq(*sub1, *(p.first)) and eq(*one, *p.second)) {
exists = true;
} else if (eq(*sub1, *(p.first))
and not eq(*one, *p.second)) {
exists = true;
Mul::dict_add_term_new(outArg(coef), d,
sub(p.second, one), p.first);
} else {
Mul::dict_add_term_new(outArg(coef), d, p.second,
p.first);
}
}
if (exists) {
if (is_a_Number(*rep)) {
if (down_cast<const Number &>(*rep).is_zero()) {
result_ = rep;
return;
}
imulnum(outArg(coef),
rcp_static_cast<const Number>(rep));
} else if (is_a<Mul>(*rep)) {
RCP<const Mul> tmp = rcp_static_cast<const Mul>(rep);
imulnum(outArg(coef), tmp->get_coef());
for (const auto &q : tmp->get_dict()) {
Mul::dict_add_term_new(outArg(coef), d, q.second,
q.first);
}
} else {
RCP<const Basic> exp, t;
Mul::as_base_exp(rep, outArg(exp), outArg(t));
Mul::dict_add_term_new(outArg(coef), d, exp, t);
}
} else
d = x.get_dict();
} else {
exists = false;
for (const auto &p : dict) {
if (eq(*sub1, *(p.first))) {
exists = true;
} else {
Mul::dict_add_term_new(outArg(coef), d, p.second,
p.first);
}
}
if (exists) {
if (is_a_Number(*rep)) {
if (down_cast<const Number &>(*rep).is_zero()) {
result_ = rep;
return;
}
imulnum(outArg(coef),
rcp_static_cast<const Number>(rep));
} else if (is_a<Mul>(*rep)) {
RCP<const Mul> tmp = rcp_static_cast<const Mul>(rep);
imulnum(outArg(coef), tmp->get_coef());
for (const auto &q : tmp->get_dict()) {
Mul::dict_add_term_new(outArg(coef), d, q.second,
q.first);
}
} else {
RCP<const Basic> exp, t;
Mul::as_base_exp(rep, outArg(exp), outArg(t));
Mul::dict_add_term_new(outArg(coef), d, exp, t);
}
} else
d = x.get_dict();
}
dict.clear();
dict.insert(d.begin(), d.end());
}
result_ = Mul::from_dict(coef, std::move(d));
}

void bvisit(const Pow &x)
{
Expand Down
104 changes: 104 additions & 0 deletions symengine/tests/basic/test_subs.cpp
Expand Up @@ -130,9 +130,14 @@ TEST_CASE("Mul: subs", "[subs]")
RCP<const Basic> y = symbol("y");
RCP<const Basic> z = symbol("z");
RCP<const Basic> w = symbol("w");
RCP<const Basic> i1 = integer(1);
RCP<const Basic> i2 = integer(2);
RCP<const Basic> i3 = integer(3);
RCP<const Basic> i4 = integer(4);
RCP<const Basic> i5 = integer(5);
RCP<const Basic> i6 = integer(6);
RCP<const Basic> i15 = integer(15);
RCP<const Basic> i16 = integer(16);

RCP<const Basic> r1 = mul(x, y);
RCP<const Basic> r2 = pow(y, i2);
Expand All @@ -152,6 +157,77 @@ TEST_CASE("Mul: subs", "[subs]")
r2 = z;
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = z;
r1 = mul(mul(x, y), z);
r2 = pow(z, i2);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = mul(y, z);
r1 = mul(mul(x, y), z);
r2 = mul(y, pow(z, i2));
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = i4;
r1 = mul(mul(x, y), z);
r2 = mul(z, i4);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = z;
r1 = add(add(mul(mul(pow(y, i2), x), i2), mul(i3, pow(x, i2))),
mul(i4, pow(y, i2)));
r2 = add(add(mul(i3, pow(x, i2)), mul(i4, pow(y, i2))), mul(mul(i2, y), z));
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(mul(x, y), z)] = i2;
r1 = add(
add(mul(mul(mul(pow(y, i3), x), i2), pow(z, i2)), mul(i3, pow(x, i2))),
mul(i4, pow(y, i2)));
r2 = add(add(mul(i3, pow(x, i2)), mul(i4, pow(y, i2))),
mul(mul(pow(y, i2), z), i4));
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = z;
r1 = mul(x, z);
r2 = mul(x, z);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[pow(x, y)] = z;
r1 = mul(z, pow(x, y));
r2 = pow(z, i2);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[pow(x, i2)] = z;
r1 = mul(z, pow(x, i2));
r2 = pow(z, i2);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(y, pow(z, i2))] = x;
r1 = add(add(mul(mul(x, y), pow(z, i2)), mul(mul(z, y), pow(x, i2))),
mul(y, pow(z, i3)));
r2 = add(add(pow(x, i2), mul(mul(z, y), pow(x, i2))), mul(x, z));
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[x] = i2;
r1 = mul(pow(x, i2), y);
r2 = mul(i4, y);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[x] = i2;
r1 = add(mul(pow(x, i2), y), mul(pow(x, i2), z));
r2 = add(mul(i4, y), mul(i4, z));
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[pow(x, y)] = z;
r1 = mul(i2, pow(x, y));
Expand All @@ -176,6 +252,34 @@ TEST_CASE("Mul: subs", "[subs]")
r1 = div(one, mul(x, y));
d[x] = zero;
REQUIRE(eq(*r1->subs(d), *div(ComplexInf, y)));

d.clear();
d[mul(x, y)] = i2;
d[mul(z, w)] = i4;
r1 = mul(mul(mul(mul(x, y), z), w), i2);
r2 = i16;
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(x, y)] = i2;
d[z] = i1;
r1 = mul(mul(mul(mul(x, y), z), w), i2);
r2 = mul(w, i4);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(mul(x, y), i2)] = i5;
d[mul(mul(z, w), i2)] = i3;
r1 = mul(mul(mul(mul(x, y), z), w), i2);
r2 = mul(mul(z, w), i5);
REQUIRE(eq(*r1->subs(d), *r2));

d.clear();
d[mul(mul(x, y), i2)] = i5;
d[mul(z, w)] = i3;
r1 = mul(mul(mul(mul(x, y), z), w), i2);
r2 = i15;
REQUIRE(eq(*r1->subs(d), *r2));
}

TEST_CASE("Pow: subs", "[subs]")
Expand Down