Friday, June 22, 2012

Tree Implementation with Iterators

There is no STL tree class, so you'll have to make your own. There are many choices to be made, with different trade-offs. I'll share my own experiences here (for clarity, I will not write it as a template).

Let's say you want to represent mathematical expressions using syntax trees. The most obvious way is probably to start with:


// First attempt at a tree node: children are stored by value.
class Node {
  private:
    // Link back to the node that contains this one.
    // (A single leading underscore followed by a lowercase letter is not a
    // reserved identifier for class members.)
    Node* _parent;

    // Direct sub-expressions of this node
    std::vector<Node> _children;
};


However, mathematical expressions must consist of variables (x, y, z), constants (0, 1, 42), and functions/operators (+, - , *, /). Let's keep things simple and assume that we are only dealing with basic arithmetic, using integers with addition, subtraction, and multiplication. (But we allow for custom functions to be added later, so we make no assumption of having a binary tree. Hence we keep the vector representation for the children.)

The problem then is that the nodes will have different types (variables, constants, operators). So we need polymorphic types:


// Second attempt: children are held through pointers so that the elements
// may be objects of classes derived from Node.
class Node {
  private:
    // Link back to the node that contains this one
    Node* _parent;

    // Direct sub-expressions of this node (vector of pointers)
    std::vector<Node*> _children;
};


We add some typedefs and easy members:


// Abstract base class for the nodes of an expression tree.
//
// A node owns its children: they are passed in as raw pointers and are
// deleted by the destructor, so they must have been allocated with new.
class Node {
  public:
  typedef std::vector<Node*>::iterator arg_iterator;
  typedef std::vector<Node*>::const_iterator const_arg_iterator;

  typedef std::vector<Node*>::size_type arity_type;


  // Default Constructor: no parent, no children
  Node() : _parent(nullptr) { }
  // Construct node with one child
  Node(Node* n) : _parent(nullptr) { add_child(n); }

  // Construct node with two children
  Node(Node* n1, Node* n2) : _parent(nullptr) { add_child(n1); add_child(n2); }
  // Children are owned through raw pointers, so a memberwise copy would make
  // two nodes delete the same children. Copying is therefore disabled.
  Node(const Node&) = delete;
  Node& operator=(const Node&) = delete;
  // Destructor. Must be virtual: children are deleted through Node*, and
  // deleting a derived object through a base pointer is undefined otherwise.
  virtual ~Node() { for (Node* child : _children) delete child; }


  // Evaluate expression
  virtual int evaluate() const = 0; // must be overridden
  // Get arity (number of children)
  arity_type arity() const { return _children.size(); }
  bool is_leaf() const { return _children.empty(); }
  // Set/Get parent
  Node*& parent() { return _parent; }
  const Node* parent() const { return _parent; }
  // Iterator access to arguments
  arg_iterator arg_begin() { return _children.begin(); }
  const_arg_iterator arg_begin() const { return _children.begin(); }
  arg_iterator arg_end() { return _children.end(); }

  const_arg_iterator arg_end() const { return _children.end(); }
  // Erase argument (use const_arg_iterator if your compiler properly supports C++11)
  arg_iterator  arg_erase(arg_iterator i) { return _children.erase(i); }
  // Get argument. The node must have at least one child (for arg(i): more
  // than i children); no bounds checking is performed.
  Node* arg_first() { return _children.front(); }
  Node* arg_last() { return _children.back(); }
  Node* arg(arity_type i) { return _children[i]; }
  // Const overloads, needed by const members such as evaluate() and print()
  const Node* arg_first() const { return _children.front(); }
  const Node* arg_last() const { return _children.back(); }
  const Node* arg(arity_type i) const { return _children[i]; }

  // Append n as the right-most child (defined outside the class)
  void add_child(Node* n);
  // Remove this node from its parent's list of children (defined outside the class)
  void remove_parent_ptr();


  // Print subtree
  virtual void print(std::ostream&) const = 0;

  private:
  Node* _parent;                 // nullptr for the root of a tree
  std::vector<Node*> _children;  // owned; deleted in the destructor
};



We define a member function to add a child to the right:


// Append n as the new right-most child of this node.
// n must not be null; if it currently hangs under another node it is
// unlinked from that node first.
void Node::add_child(Node* n)
{
  // Unlink from the previous parent's child list, if there is one
  n->remove_parent_ptr();

  // Re-parent, then register as last child
  n->_parent = this;
  _children.push_back(n);
}



void Node::remove_parent_ptr()
{
  if (parent()) {
    auto i = parent()->arg_begin();
    while (*i != this) ++i; // Must be found, don't check arg_end()
    parent()->arg_erase(i);
  }
}


With this we can define our derived classes:


class Constant : public Node {
  public:
    Constant(int v) : val(v) {}
    int evaluate() const { return val; }
  protected:
    int val;
};

// Exception class thrown when evaluating non-ground expressions
class non_ground {};


class Variable : public Node {
  public:
    Variable(const std::string& s) : name(s) {}
    int evaluate() const { throw non_ground(); }
    void print(std::ostream& os) const { os << _name; }
protected:
    std::string name;
};


The purpose of the Variable class is to represent expressions containing unassigned variables, like "3*x+1". We can then implement a member function evaluate(const std::map<std::string,int>&) which is identical to evaluate() except for the variable class, in which case it looks up the variable to value mapping and returns it (or throws non_ground() if no such mapping is found).


class Add : public Node {
  public:
    Add(Node* n1, Node* n2) : Node(n1,n2) { }
    int evaluate() const { return arg_first()->evaluate() + arg_last()->evaluate(); }
    void print(std::ostream& os) const {
      os << "(";
      arg_first()->print(os); os << "+"; arg_last()->print(os);
      os << ")";
    }
};



class Sub : public Node {
  public:
    Sub(Node* n) : Node(n) { }
    Sub(Node* n1, Node* n2) : Node(n1,n2) { }
    int evaluate() const {
      if (is_leaf()) {
 return -arg_first()->evaluate();
      } else {
 return arg_first()->evaluate() - arg_last()->evaluate();
      }
    }
    void print(std::ostream& os) const {
      if (is_leaf()) {
 os << "-"; arg_first()->print(os);
      } else {
 os << "(";
 arg_first()->print(os); os << "-"; arg_last()->print(os);
 os << ")";
      }
    }
};



class Mul : public Node {
  public:
    Mul(Node* n1, Node* n2) : Node(n1,n2) { }
    int evaluate() const { return arg_first()->evaluate() * arg_last()->evaluate(); }
 void print(std::ostream& os) const {
  arg_first()->print(os); os << "*"; arg_last()->print(os);
 }
};


We can test it:



#include <iostream>
#include "tree.h"

// Builds two small expression trees and prints each one with its value.
// The roots live on the stack; their destructors delete the children.
int main()
{
 // (3*12)+6
 Add root(
  new Mul(new Constant(3), new Constant(12)),
  new Constant(6));
 std::cout << "Evaluating " << root;
 std::cout << " => " << root.evaluate();
 std::cout << "\n";

 // 3*(12+6)
 Mul root2(
  new Constant(3),
  new Add(new Constant(12), new Constant(6)));
 std::cout << "Evaluating " << root2;
 std::cout << " => " << root2.evaluate();
 std::cout << "\n";
}


Output is:

Evaluating (3*12+6) => 42
Evaluating 3*(12+6) => 54

Finally, we add iterators to traverse the tree depth first (similar techniques can be used to create breadth first iterators or even best-first by supplying a lambda).

There are two ways we can implement iterators: either we use a stack containing the expanded nodes, or we "manually" backtrack the tree. The first method leads to somewhat faster traversal in theory, since getting to the next node is O(1). On the other hand, we need to add all expanded nodes to the stack, and whenever we copy an iterator, we need to copy the entire stack.

STL algorithms are designed around an assumption of fast copying of iterators, so we pick the lighter iterator approach (manually backtracking).

We may start with the necessary typedefs:



// Constant iterator over the nodes of an expression tree.
// It is a thin wrapper around a pointer to the current node; a null pointer
// represents the past-the-end position.
class node_citerator {
        const Node* nptr; // pointer to node
public:
 typedef std::bidirectional_iterator_tag iterator_category;
 typedef Node value_type;
 // This is a constant iterator: pointer and reference must agree with the
 // return types of operator->() and operator*() below.
 typedef const Node* pointer;
 typedef const Node& reference;
 typedef ptrdiff_t difference_type;

        // Construct iterator from pointer to node
 node_citerator(const Node* fp) : nptr(fp) { }

        // Compare two iterators: equal when they refer to the same node
 bool operator==(const node_citerator& di) const { return nptr == di.nptr; }
 bool operator!=(const node_citerator& di) const { return !(*this == di); }

        // Dereferencing yields the node itself, not a pointer to it
 const Node& operator*() const { return *nptr;  }
 const Node* operator->() const { return nptr; }

        // The increment and decrement operators also belong in this class.
};


Here I have chosen to dereference directly to Node, rather than a pointer to node. This makes the iterator interface easier to use with STL algorithms (without the need to invoke lambdas to dereference all the time).


Our begin() and end() would then look as follows:


// Iterator positioned on this node, the first node of the traversal.
inline Node::const_iterator Node::begin() const
{
 const_iterator first(this);
 return first;
}

// Past-the-end iterator, represented by a null node pointer.
inline Node::const_iterator Node::end() const
{
 const_iterator past_the_end(nullptr);
 return past_the_end;
}


There is however a potential problem with this approach: it will not work if we wish to traverse a proper subtree, as opposed to the entire tree. This is because we are not storing the next nodes to visit, so when backtracking we essentially climb assuming the first node was the root (i.e. the node which has a nullptr parent pointer). As long as we only call begin() and end() from the root node, no surprises emerge. But consider:

Add root(new Mul(new Constant(3), new Constant(12)), new Constant(6));
// Attempt to visit only the subtree below the first child of root
std::for_each(root.arg_first()->begin(), root.arg_first()->end(), my_lambda);

This will not work as expected, since there is no clean way to implement the iterator's increment operator so that it knows where it started (i.e. that begin() and end() were called from root.arg_first(), and not root).

So we need to store the root node pointer too:


// Revised constant iterator (partial listing): in addition to the current
// node it stores the node that the traversal was started from.
class node_citerator {
        const Node* nptr; // pointer to current node
        const Node* root; // pointer to subtree root (the node begin()/end() was called on)
public:
        // Same typedefs as before
        // Construct iterator from pointer to node (fp), remember caller node (r)
 node_citerator(const Node* fp, const Node* r) : nptr(fp), root(r) { }
        // Comparisons may be defined as before. Comparing two iterators with different
        // caller nodes may simply be considered undefined behavior, as it doesn't make any sense


Now our begin() and end() tell the iterator which node the call came from using the second argument:



// Iterator positioned on this node; this node is also recorded as the root
// of the subtree being traversed.
inline Node::const_iterator Node::begin() const
{
 const_iterator first(this, this);
 return first;
}

// Past-the-end iterator (null node pointer) for the subtree rooted at this node.
inline Node::const_iterator Node::end() const
{
 const_iterator past_the_end(nullptr, this);
 return past_the_end;
}



Prefix increment would then look like this:



// Prefix increment
 // Advance to the next node in depth-first order (a node is visited before
 // its children). Stepping past the last node of the subtree sets nptr to
 // nullptr, which is the end position. Must not be called on the end iterator.
 node_citerator& operator++() {
  // Inner node: the next node is simply its left-most child
  if (!nptr->is_leaf()) {
   nptr = nptr->arg_first();
   return *this;
  }

  // Leaf node: climb until a node with a sibling to its right is found,
  // but never climb above the root of the subtree being traversed
  while (nptr != root) {
   // nptr is not root here, so it must have a parent
   const Node* up = nptr->parent();
   // Locate nptr among the children of its parent
   // (no bounds check: nptr is known to be in the array)
   auto pos = up->arg_begin();
   while (*pos != nptr) ++pos;
   ++pos; // the sibling to the right, if any
   if (pos != up->arg_end()) {
    // Descend into the next sibling
    nptr = *pos;
    return *this;
   }
   // nptr was the last child: keep climbing
   nptr = up;
  }

  // Arrived back at the subtree root: traversal is complete
  nptr = nullptr;
  return *this;
 }



Prefix decrement may be defined similarly:



// Prefix decrement
 // Step back to the previous node in depth-first order.
 // Returns a reference to this iterator, like operator++ (returning a copy
 // would make --(--it) decrement a temporary the second time).
 // Decrementing an iterator that is on the first node is undefined.
 node_citerator& operator--() {
  if (nptr) {
   // note: -- on first element is undefined => nptr has a parent here
   const Node* par = nptr->parent();
   if (nptr == par->arg_first()) {
    // nptr is first child => the previous node is the parent
    nptr = par;
   } else {
    // nptr is not first child => go to the sibling on the left, then
    // traverse down. Search backwards from the last child; the iterator
    // is stored in a variable first because decrementing the temporary
    // returned by arg_end() is not portable.
    auto prev = par->arg_end();
    --prev;
    for ( ; *prev != nptr; --prev);
    --prev; // previous from nptr (nptr is not the first child, so this is valid)
    nptr = *prev;
    // Now traverse down right most
    while (!nptr->is_leaf()) nptr = nptr->arg_last();
   }
  } else {
   // nptr at end, so we need to use root to get to last element
   for (nptr = root; !nptr->is_leaf(); ) {
    nptr = nptr->arg_last();
   }
  }
  return *this;
 }

No comments:

Post a Comment