Skip to content

Commit

Permalink
Document and refactor KBO::State a bit; add check for global ordering…
Browse files Browse the repository at this point in the history
… when manipulating KBO weight in Term
  • Loading branch information
mezpusz committed Apr 29, 2024
1 parent e3f9e87 commit b7b6b73
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 107 deletions.
252 changes: 147 additions & 105 deletions Kernel/KBO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ using namespace Shell;

/**
* Class to represent the current state of the KBO comparison.
* Based on Bernd Loechner's "Things to Know when Implementing KBO"
* (https://doi.org/10.1007/s10817-006-9031-4)
* @since 30/04/2008 flight Brussels-Tel Aviv
*/
class KBO::State
Expand All @@ -59,8 +61,24 @@ class KBO::State
_varDiffs.reset();
}

template<bool unidirectional>
Result traverseLex(AppliedTerm t1, AppliedTerm t2);
/**
* Lexicographic traversal of two terms with same top symbol,
* i.e. traversing their symbols in lockstep, as descibed in
* the Loechner et al. paper above. It performs a bidirectional
* comparison between the two terms, i.e. we can get any value
* of @b Result.
*/
Result traverseLexBidir(AppliedTerm t1, AppliedTerm t2);
/**
* Optimised, unidirectional version of @b traverseLexBidir
* where we only care about @b GREATER and @b EQUAL, otherwise
* it returns as early as possible with @b INCOMPARABLE.
*/
Result traverseLexUnidir(AppliedTerm t1, AppliedTerm t2);
/**
* Performs a non-lexicographic (i.e. non-lockstep) traversal
* of two terms in case their top symbols are not the same.
*/
template<bool unidirectional>
Result traverseNonLex(AppliedTerm t1, AppliedTerm t2);

Expand Down Expand Up @@ -231,8 +249,7 @@ void KBO::State::traverse(AppliedTerm tt)
}
}

template<bool unidirectional>
Ordering::Result KBO::State::traverseLex(AppliedTerm tl1, AppliedTerm tl2)
Ordering::Result KBO::State::traverseLexBidir(AppliedTerm tl1, AppliedTerm tl2)
{
ASS(tl1.term.isTerm() && tl2.term.isTerm());
auto t1 = tl1.term.term();
Expand All @@ -254,20 +271,71 @@ Ordering::Result KBO::State::traverseLex(AppliedTerm tl1, AppliedTerm tl2)
auto [ss,ssAboveVar] = stack.pop(); // tl1 subterm
if(ss->isEmpty()) {
ASS(tt->isEmpty());

if constexpr (unidirectional) {
if (!checkVars()) {
return INCOMPARABLE;
}
} else {
depth--;
if(_lexResult!=EQUAL && depth<lexValidDepth) {
lexValidDepth=depth;
if(_weightDiff!=0) {
_lexResult=_weightDiff>0 ? GREATER : LESS;
}
_lexResult=applyVariableCondition(_lexResult);
depth--;
if(_lexResult!=EQUAL && depth<lexValidDepth) {
lexValidDepth=depth;
if(_weightDiff!=0) {
_lexResult=_weightDiff>0 ? GREATER : LESS;
}
_lexResult=applyVariableCondition(_lexResult);
}
continue;
}

stack.push(make_pair(ss->next(),ssAboveVar));
stack.push(make_pair(tt->next(),ttAboveVar));

AppliedTerm s(*ss,tl1.applicator,ssAboveVar);
AppliedTerm t(*tt,tl2.applicator,ttAboveVar);

if(s.equalsShallow(t)) {
//if content is the same, neither weightDiff nor varDiffs would change
continue;
}
if(TermList::sameTopFunctor(s.term,t.term)) {
ASS(s.term.isTerm());
ASS(t.term.isTerm());
ASS(s.term.term()->arity());
stack.push(make_pair(s.term.term()->args(),s.aboveVar));
stack.push(make_pair(t.term.term()->args(),t.aboveVar));
depth++;
} else {
traverse<1,/*unidirectional=*/false>(s);
traverse<-1,/*unidirectional=*/false>(t);
if(_lexResult==EQUAL) {
_lexResult=innerResult(s.term, t.term);
lexValidDepth=depth;
ASS(_lexResult!=EQUAL);
ASS(_lexResult!=GREATER_EQ);
ASS(_lexResult!=LESS_EQ);
}
}
}
return result(tl1,tl2);
}

Ordering::Result KBO::State::traverseLexUnidir(AppliedTerm tl1, AppliedTerm tl2)
{
ASS(tl1.term.isTerm() && tl2.term.isTerm());
auto t1 = tl1.term.term();
auto t2 = tl2.term.term();

ASS(t1->functor()==t2->functor());
ASS(t1->arity());
ASS_EQ(_lexResult, EQUAL);

static Stack<pair<const TermList*,bool>> stack(32);
stack.reset();
stack.push(make_pair(t1->args(),tl1.aboveVar));
stack.push(make_pair(t2->args(),tl2.aboveVar));
while(!stack.isEmpty()) {
auto [tt,ttAboveVar] = stack.pop(); // tl2 subterm
auto [ss,ssAboveVar] = stack.pop(); // tl1 subterm
if(ss->isEmpty()) {
ASS(tt->isEmpty());

if (!checkVars()) {
return INCOMPARABLE;
}
continue;
}
Expand All @@ -282,92 +350,67 @@ Ordering::Result KBO::State::traverseLex(AppliedTerm tl1, AppliedTerm tl2)
//if content is the same, neither weightDiff nor varDiffs would change
continue;
}
if constexpr (unidirectional) {
if (_lexResult==EQUAL) {
auto ssw = _kbo.computeWeight(s);
auto ttw = _kbo.computeWeight(t);
if (ssw < ttw) {
if (_lexResult==EQUAL) {
auto ssw = _kbo.computeWeight(s);
auto ttw = _kbo.computeWeight(t);
if (ssw < ttw) {
return INCOMPARABLE;
}
if (ssw > ttw) {
traverse<1,/*unidirectional=*/true>(s);
traverse<-1,/*unidirectional=*/true>(t);
if (!checkVars()) {
return INCOMPARABLE;
}
if (ssw > ttw) {
traverse<1,unidirectional>(s);
traverse<-1,unidirectional>(t);
if (!checkVars()) {
return INCOMPARABLE;
}
_lexResult = INCOMPARABLE;
continue;
_lexResult = INCOMPARABLE;
continue;
}
// ssw == ttw
if (s.term.isVar()) {
return INCOMPARABLE;
}
if (t.term.isVar()) {
if (!s.containsVar(t.term)) {
return INCOMPARABLE;
}
// ssw == ttw
if (s.term.isVar()) {
_lexResult = INCOMPARABLE;
continue;
}
Result comp = s.term.term()->isSort()
? _kbo.compareTypeConPrecedences(s.term.term()->functor(),t.term.term()->functor())
: _kbo.compareFunctionPrecedences(s.term.term()->functor(),t.term.term()->functor());
switch (comp)
{
case Ordering::LESS:
case Ordering::LESS_EQ: {
return INCOMPARABLE;
}
if (t.term.isVar()) {
if (!s.containsVar(t.term)) {
case Ordering::GREATER:
case Ordering::GREATER_EQ: {
traverse<1,/*unidirectional=*/true>(s);
traverse<-1,/*unidirectional=*/true>(t);
if (!checkVars()) {
return INCOMPARABLE;
}
_lexResult = INCOMPARABLE;
continue;
break;
}
Result comp = s.term.term()->isSort()
? _kbo.compareTypeConPrecedences(s.term.term()->functor(),t.term.term()->functor())
: _kbo.compareFunctionPrecedences(s.term.term()->functor(),t.term.term()->functor());
switch (comp)
{
case Ordering::LESS:
case Ordering::LESS_EQ: {
return INCOMPARABLE;
}
case Ordering::GREATER:
case Ordering::GREATER_EQ: {
traverse<1,unidirectional>(s);
traverse<-1,unidirectional>(t);
if (!checkVars()) {
return INCOMPARABLE;
}
_lexResult = INCOMPARABLE;
break;
}
case Ordering::EQUAL: {
stack.push(make_pair(s.term.term()->args(),s.aboveVar));
stack.push(make_pair(t.term.term()->args(),t.aboveVar));
break;
}
default: ASSERTION_VIOLATION;
case Ordering::EQUAL: {
stack.push(make_pair(s.term.term()->args(),s.aboveVar));
stack.push(make_pair(t.term.term()->args(),t.aboveVar));
break;
}
} else {
traverse<1,unidirectional>(s);
traverse<-1,unidirectional>(t);
default: ASSERTION_VIOLATION;
}
} else {
if(TermList::sameTopFunctor(s.term,t.term)) {
ASS(s.term.isTerm());
ASS(t.term.isTerm());
ASS(s.term.term()->arity());
stack.push(make_pair(s.term.term()->args(),s.aboveVar));
stack.push(make_pair(t.term.term()->args(),t.aboveVar));
depth++;
} else {
traverse<1,unidirectional>(s);
traverse<-1,unidirectional>(t);
if(_lexResult==EQUAL) {
_lexResult=innerResult(s.term, t.term);
lexValidDepth=depth;
ASS(_lexResult!=EQUAL);
ASS(_lexResult!=GREATER_EQ);
ASS(_lexResult!=LESS_EQ);
}
}
traverse<1,/*unidirectional=*/true>(s);
traverse<-1,/*unidirectional=*/true>(t);
}
}
if constexpr (unidirectional) {
if (_lexResult==EQUAL) {
return EQUAL;
}
return checkVars() ? GREATER : INCOMPARABLE;
} else {
return result(tl1,tl2);
if (_lexResult==EQUAL) {
return EQUAL;
}
return checkVars() ? GREATER : INCOMPARABLE;
}

template<bool unidirectional>
Expand Down Expand Up @@ -774,23 +817,22 @@ Ordering::Result KBO::comparePredicates(Literal* l1, Literal* l2) const
//this is to make sure _state isn't used while we're using it
_state=0;
#endif
constexpr bool unidirectional = false;
state->init();
if(p1!=p2) {
TermList* ts;
ts=l1->args();
while(!ts->isEmpty()) {
state->traverse<1,unidirectional>(AppliedTerm(*ts));
state->traverse<1,/*unidirectional=*/false>(AppliedTerm(*ts));
ts=ts->next();
}
ts=l2->args();
while(!ts->isEmpty()) {
state->traverse<-1,unidirectional>(AppliedTerm(*ts));
state->traverse<-1,/*unidirectional=*/false>(AppliedTerm(*ts));
ts=ts->next();
}
res=state->result(AppliedTerm(TermList(l1)),AppliedTerm(TermList(l2)));
} else {
res=state->traverseLex<unidirectional>(AppliedTerm(TermList(l1)),AppliedTerm(TermList(l2)));
res=state->traverseLexBidir(AppliedTerm(TermList(l1)),AppliedTerm(TermList(l2)));
}

#if VDEBUG
Expand Down Expand Up @@ -829,13 +871,12 @@ Ordering::Result KBO::compare(AppliedTerm tl1, AppliedTerm tl2) const
_state=0;
#endif

constexpr bool unidirectional = false;
state->init();
Result res;
if(t1->functor()==t2->functor()) {
res = state->traverseLex<unidirectional>(tl1,tl2);
res = state->traverseLexBidir(tl1,tl2);
} else {
res = state->traverseNonLex<unidirectional>(tl1,tl2);
res = state->traverseNonLex</*unidirectional=*/false>(tl1,tl2);
}
#if VDEBUG
_state=state;
Expand Down Expand Up @@ -875,12 +916,11 @@ Ordering::Result KBO::isGreaterOrEq(AppliedTerm tl1, AppliedTerm tl2) const
_state=0;
#endif

constexpr bool unidirectional = true;
state->init();
Result res;
if (w1>w2) {
// traverse variables
res = state->traverseNonLex<unidirectional>(tl1,tl2);
res = state->traverseNonLex</*unidirectional=*/true>(tl1,tl2);
#if VDEBUG
_state=state;
#endif
Expand All @@ -899,11 +939,11 @@ Ordering::Result KBO::isGreaterOrEq(AppliedTerm tl1, AppliedTerm tl2) const
}
case Ordering::GREATER:
case Ordering::GREATER_EQ: {
res = state->traverseNonLex<unidirectional>(tl1,tl2);
res = state->traverseNonLex</*unidirectional=*/true>(tl1,tl2);
break;
}
case Ordering::EQUAL: {
res = state->traverseLex<unidirectional>(tl1,tl2);
res = state->traverseLexUnidir(tl1,tl2);
break;
}
case Ordering::INCOMPARABLE:
Expand Down Expand Up @@ -939,8 +979,10 @@ unsigned KBO::computeWeight(AppliedTerm tt) const
if (tt.term.isVar()) {
return _funcWeights._specialWeights._variableWeight;
}
if (!tt.aboveVar && tt.term.term()->kboWeight()!=-1) {
return tt.term.term()->kboWeight();
const bool useCache = (tryGetGlobalOrdering() == this);

if (!tt.aboveVar && useCache && tt.term.term()->kboWeight(this)!=-1) {
return tt.term.term()->kboWeight(this);
}
// stack of [non-zero arity terms, current argument position, accumulated weight]
struct State {
Expand All @@ -958,17 +1000,17 @@ unsigned KBO::computeWeight(AppliedTerm tt) const

if (t.term.isVar()) {
curr.weight += _funcWeights._specialWeights._variableWeight;
} else if (!t.aboveVar && t.term.term()->kboWeight()!=-1) {
curr.weight += t.term.term()->kboWeight();
} else if (!t.aboveVar && useCache && t.term.term()->kboWeight(this)!=-1) {
curr.weight += t.term.term()->kboWeight(this);
} else {
recState.push(State{ t, 0, (unsigned)symbolWeight(t.term.term()) });
}

} else {

auto orig = recState.pop();
if (!orig.t.aboveVar) {
const_cast<Term*>(orig.t.term.term())->setKboWeight(orig.weight);
if (!orig.t.aboveVar && useCache) {
const_cast<Term*>(orig.t.term.term())->setKboWeight(orig.weight, this);
}
if (recState.isEmpty()) {
return orig.weight;
Expand Down
6 changes: 6 additions & 0 deletions Kernel/Term.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,9 @@ Term::Term(const Term& t) throw()
_isTwoVarEquality(0),
_weight(0),
_kboWeight(-1),
#if VDEBUG
_kboInstance(nullptr),
#endif
_vars(0)
{
ASS(!isSpecial()); //we do not copy special terms
Expand Down Expand Up @@ -1527,6 +1530,9 @@ Term::Term() throw()
_isTwoVarEquality(0),
_weight(0),
_kboWeight(-1),
#if VDEBUG
_kboInstance(nullptr),
#endif
_maxRedLen(0),
_vars(0)
{
Expand Down

0 comments on commit b7b6b73

Please sign in to comment.