Pages

Search This Blog

January 06, 2013

Convert infix equation to binary tree and solve

Logic inspired by Dijkstra's shunting yard algorithm

#include <iostream>
#include <iomanip>
#include <stack>
#include <conio.h>

#define is_operator(x) ((x) == \'+\' || (x) == \'-\' || (x) == \'*\' || (x) == \'/\')

/*
- read a string representing any equation and convert that into a 
binary tree with operators at the roots
- then solve the binary tree using recursive methods to arrive at the 
value of the equation

***SEEME*** - works only with equations having single digit numbers
***SEEME*** - equations are well paranthesized so no need of associativity
*/

struct Node
{
    Node(void *x):data(x), left(0), right(0){}
    void *data;
    Node *left, *right;
};

/*
utils
*/

/*
prints the binary tree in pre-order form
*/
void print_tree_helper(Node *root)
{    
    // cast it appropriately 
    char c = *(char *)root->data;
    
    if(is_operator(c))    
        std::cout << c << " ";
    else
        std::cout << *(int *)root->data << " ";

    if(root->left != NULL)
        print_tree_helper(root->left);
        
    if(root->right != NULL)
        print_tree_helper(root->right);    
}

void print_tree(Node *root)
{
    std::cout << "binary AST in pre-order " << std::endl;
    print_tree_helper(root);
    std::cout << std::endl;
}

/*
frees the memory occupied by the binary tree rooted at root

-- post order traversal is the ideal way to clean up a tree
*/
void clean_up(Node *root)
{
    if(root->left != NULL)
        clean_up(root->left);
    if(root->right != NULL)
        clean_up(root->right);

    if(root->data)
    {
        char c = *(char *)root->data;
        if(is_operator(c))
            delete (char *)root->data;
        else
            delete (int *)root->data;
    }

    if(root)
        delete root;

    root = NULL;
}

/*
converts the passed equation into a binary tree and
returns the root of the tree
*/
Node * convert_eqn_to_bt(std::string eqn)
{
    std::stack<char *> operators;
    std::stack<Node *> nodes;

    char *o;
    int *x;
    Node *node1, *node2, *opr;
    for(size_t i = 0; i < eqn.length(); i++)
    {
        char c = eqn[i];
        switch(c)
            {
                case \'+\': case \'-\': case \'*\': case \'/\': case \'(\':
                    operators.push(new char(c));
                    break;
                case \'0\': case \'1\': case \'2\': case \'3\': case \'4\':
                case \'5\': case \'6\': case \'7\': case \'8\': case \'9\': 
                    x = new int((atoi((const char *)&c)));
                    nodes.push(new Node(x));
                    break;
                case \')\':
                    /*
                    inspired from dijkstra\'s shunting yard algorithm
                    - we have the operands so far pushed into nodes stack as nodes
                    - till we get \'(\', pop out operators, ideally we should have only 1
                    - pop out two operand/tree nodes
                    - construct a tree with operand as root and push into nodes stack
                    - pop \'(\' and repeat
                    */
                    
                    if(!nodes.empty())
                    {
                        node1 = nodes.top();
                        nodes.pop();    
                    }
                    
                    if(!nodes.empty())
                    {
                        node2 = nodes.top();
                        nodes.pop();
                    }
                                        
                    while(true)
                    {
                        if(operators.empty())
                        {
                            // fatal error - mis match in parens
                            fprintf(stderr, "fatal error in %d\\n", __LINE__);
                            return NULL;
                        }
                        
                        o = operators.top();
                        operators.pop();
                        
                        if(*o == \'(\')
                            break;
                            
                        opr = new Node(o);
                    }
                    
                    opr->right = node1;
                    opr->left = node2;
                    nodes.push(opr);    
                    break;
                default:
                    fprintf(stderr, "fatal error in %d\\n", __LINE__);
                    return NULL;
            }
    }
    
    while(true)
    {
        // check if we have remaining in operators
        if(operators.empty()) return nodes.top();
        
        o = operators.top();
        operators.pop();
        opr = new Node(o);

        if(!nodes.empty())
        {
            node1 = nodes.top();
            nodes.pop();
        }
        
        if(!nodes.empty())
        {
            node2 = nodes.top();
            nodes.pop();    
        }
        
        opr->right = node1;
        opr->left = node2;
        nodes.push(opr);
    }
    
    // error !
    return NULL;
}

/*
evaluates the binary AST rooted at root and
returns the result
*/
int evaluate_tree(Node *root)
{
    // cast it appropriately 
    char c = *(char *)root->data;
    int x = 0, y = 0;
    if(is_operator(c))
    {    
        x = evaluate_tree(root->left);
        y = evaluate_tree(root->right);
        
        switch(c)
        {
            case \'+\': return x+y;
            case \'-\': return x-y;
            case \'*\': return x*y;
            case \'/\': return x/y;
            default:
                fprintf(stderr, "fatal error in %d\\n", __LINE__);
                return 0;
        }
    }
    else
        return *(int *)root->data;
}

/*
test suite
*/
void testA()
{
    Node *root = convert_eqn_to_bt("((2+4)*(5+3))+((8*9)-9)-(5-(3*2))");
    assert(evaluate_tree(root) == 112);
    clean_up(root);
}

void testB()
{
    Node *root = convert_eqn_to_bt("((6*4)*7)+(((6/3)-1)-(9-(3*2)))");
    assert(evaluate_tree(root) == 166);
    clean_up(root);    
}

void testC()
{
    Node* root = convert_eqn_to_bt("4+(3+(2*(3-1)))");
    assert(evaluate_tree(root) == 11);
    clean_up(root);    
}

void tests()
{
    testA();
    testB();
    testC();
    
    std::cout << "all tests passed !" << std::endl;
}

int main()
{
    tests();
    getch();    
}

No comments:

Post a Comment