// Copyright (C) 2011 Microsoft Research
// CM Wintersteiger, 2011

#ifndef _MINISAT1P_H_
#define _MINISAT1P_H_

#include <map>
#include <stdexcept>
#include <vector>

#include <Solver.h>
#include <Proof.h>

#include "satsolver.h"
#include "interpolator.h"

#include "expression.h"
#include "ExpressionManager.h"

class MiniSAT_1p : public SATSolver, public Solver
{
public:
  typedef enum { NORMAL, CHECKING, LIFTING } ModelMode;
  typedef enum { TSEITIN_EXTENSION, PLAISTED_GREENBAUM_EXTENSION, EXPLOSION } ConstraintMode;
  
  MiniSAT_1p(ExpressionManager &m, bool proof=true, ModelMode mm=NORMAL);
  virtual ~MiniSAT_1p(void);

  virtual bool addClause(const std::vector<signed> &literals);
  virtual bool addUnit(signed l);
  virtual void setVariableMax(unsigned n);
  virtual void setClauseMax(unsigned n);

  virtual signed addVar(void);
  virtual unsigned numClauses(void) const;  
  virtual unsigned numVars(void) const;

  virtual bool solve(void);
  virtual bool solve(const std::vector<signed> &assumptions);

  inline virtual ModelValue get(signed l) const;
  virtual Expression getInterpolant(const std::vector<signed> &beta);
  virtual Expression getModel(void) const;

  virtual bool addConstraint(CExpression &e);
  virtual signed addExtension(CExpression &e);

  virtual void setVerbose(int v);

  void getConflict(std::vector<signed> &out)
  {
    out.clear();
    for (int i=0; i<conflict.size(); i++)
    {
      const Lit &l = conflict[i];
      signed x = sign(l) ? -var(l) : var(l);
      out.push_back(x);
    }
  }  

  bool addLearntClause(const std::vector<signed> &literals);
  bool addHalfLearntClause(const std::vector<signed> &literals);

private:
  vec<Lit> temp_lits;

protected:
	typedef enum { ORIGINAL_CLAUSE = 0, LEARNT_CLAUSE, HALFLEARNT_CLAUSE } ClauseType;
  

  vec<Lit> assumptions;
  std::map<unsigned, Expression> interpolants;
  std::vector<unsigned> interpolate_stack;  
  ConstraintMode constraintMode;
  
  class InterpolationTraverser : public ProofTraverser {
  public:
    InterpolationTraverser(void);
    virtual void root(const vec<Lit>& c);
    virtual void chain(const vec<ClauseId>& cs, const vec<Lit>& xs);
    virtual void deleted(ClauseId c);
    virtual void done(ClauseId e);
    virtual ~InterpolationTraverser();   
   
    class ResolutionChainInfo 
    {
    public:
      vec<ClauseId> cs;
      std::vector<signed> xs;
#ifdef _DEBUG
      vec<Lit> resolvent;
#endif
	};
    vec<ResolutionChainInfo> resChainInfo;

  protected:
    void resolve(vec<Lit>& main, vec<Lit>& other, Lit x) const;

  };  

  InterpolationTraverser iTraverser;

  Expression interpolate(void);
  Expression interpolate(int cid, const vec<Lit> &assumptions);

  typedef std::map<CExpression, Lit> ClausifyCache;
  typedef std::map<CExpression, vec<vec<Lit> > > ExplosionCache;
  ClausifyCache clausifyCache;  
  ExplosionCache explosionCache;  
  std::vector<CExpression> clausifyCNFStack;
  std::vector<CExpression> clausifyExtendStack;
  std::vector<CExpression> clausifyExplodeStack;
  std::vector<bool> assumed;
  std::vector<signed> tmp_unit;

  Lit clausify_clause(CExpression &e);
  Lit clausify_cube(CExpression &e);
  Lit clausify_cnf(CExpression &e);
  Lit extend(CExpression &e);
  Lit explode(CExpression &e);

  virtual bool resolve_conflict(Clause *confl) { return Solver::resolve_conflict(confl); }

  bool addClause(const vec<Lit> &literals, int type)
  { 

	  switch(type)
	  {
	  case ORIGINAL_CLAUSE: //ORIGINAL
		  Solver::addClause(literals);
		  break;
	  case LEARNT_CLAUSE: //LEARNT
		  Solver::addLearntClause(literals);
		  break;
	  case HALFLEARNT_CLAUSE: //HALFLEARNT
		  Solver::addHalfLearntClause(literals);
		  break;
	  default: 
		  throw std::exception("unknown clause type");
	  }
	  
	  return Solver::okay();
  }

  bool addUnit(Lit l)
  {
    Solver::addUnit(l);
    return Solver::okay();
  }
  
  void getLits(CExpression &e, vec<Lit> &lits);

  vec<vec<Lit> > new_clauses;

public:

  void removeLastToClause(std::vector<signed> &clause);

  void clearNewClauses(void) { new_clauses.clear(); }
  void addNewClauses(void);
  void clearAssumed(void) { assumed.clear(); }
  void addNewClauses(vec<vec<Lit> > interpolant_clauses);
  vec<vec<Lit> > getNewClauses(void) { return new_clauses; }
  //void setTimeout(int seconds) { Solver::setTimeout(seconds); }
  void splitDB(std::vector<signed> &sendBuffer);
  void removeLast(std::vector<signed> &sendBuffer);
  void resetProof(void){
  	 //delete proof;
	  if(proof != NULL) delete proof;
	  InterpolationTraverser iTraverser;
	  Solver::proof = new Proof(iTraverser);  
  }
  void deleteProof(void){
	if(proof != NULL) delete proof;
	Solver::proof = NULL;
  }
  //void getUsedVariables(std::set<int> &vars);
};


// Builds a Lit from a signed integer literal: the magnitude of `x` becomes
// the variable and the Lit's sign flag is set exactly when `x` is negative.
inline Lit literalToLit(signed x)
{
  if (x < 0)
    return Lit(-x, true);

  return Lit(x, false);
}

inline signed litToLiteral(const Lit &x)
{
  bool sgn = sign(x);
  signed v = var(x);
  return (sgn) ? -v : v;
}

#endif