// Copyright (C) 2011 Microsoft Research
// CM Wintersteiger, 2011

#ifdef _DEBUG
  #define _CRTDBG_MAP_ALLOC
  #define _CRTDBG_MAP_ALLOC_NEW
  #include <stdlib.h>
  #include <crtdbg.h>
#endif

#include <cassert>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <omp.h>

#include "minisat1p.h"
#include "decomposition_batch.h"
#include "decomposition_cycle.h"
#include "decomposition_rnd.h"
#include "interpolator_m.h"
#include "interpolator_im.h"
#include "interpolator_p.h"

#include "desat_old.h"

DeSAT_old::DeSAT_old(ExpressionManager &m) :
  SATSolver(m),
  globalTime(0),
  partitionsTime(0),
  importTime(0),
  maxVar(0),
  interpolants_imported(0),
  rounds(0)
{
  // Without an explicit partition count, use one partition per thread the
  // OpenMP runtime would currently make available.
  const int threads = omp_get_max_threads();
  init(threads);
}

// Creates a solver with an explicit number of partitions n.
// Initializes the same counters as the default constructor; `rounds` was
// previously left uninitialized here.
DeSAT_old::DeSAT_old(ExpressionManager &m, unsigned n) :
  SATSolver(m),
  globalTime(0),
  partitionsTime(0),
  importTime(0),
  maxVar(0),
  interpolants_imported(0),
  rounds(0)
{
  init(n);
}

void DeSAT_old::init(unsigned n)
{
  cores=n;
  s.resize(cores, NULL);
  ms.resize(cores, NULL);
  if (cores==1)
    s[0] = new MiniSAT_1p(m, false);
  else
  {
    for (unsigned i=0; i<cores; i++)
    {
      ms[i] = new ExpressionManager();
      s[i] = new MiniSAT_1p(*ms[i], true);    
    }
  }    
  translation.setLanguages(cores);
  d = new BatchDecomposition();
  //d = new CycleDecomposition();
  //d = new RandomDecompositon();
  d->setPartitions(cores);    
  assumptions.resize(cores);  
  interpolants.resize(cores);
  models.resize(cores);
  new_models.resize(cores);
  model_vars.resize(cores);
  for (unsigned i=0; i<cores; i++)  
    interpolants[i].resize(cores, m.mkNil());  
  crossModelSolver = new MiniSAT_1p(m, false);
  omp_set_num_threads(cores);
}

// Frees the partition solvers, their expression managers, the decomposition
// and the cross-model solver.
DeSAT_old::~DeSAT_old(void)
{
  // Each solver is deleted before the manager it was constructed on.
  // ms[k] is NULL in the single-partition configuration (see init); deleting
  // NULL is a no-op.
  for (size_t k = 0; k < s.size(); ++k)
  {
    delete s[k];
    delete ms[k];
  }
  s.clear();
  delete d;
  delete crossModelSolver;
}

// Applies verbosity level v to every partition solver, the cross-model solver
// and this object.
void DeSAT_old::setVerbose(int v)
{
  const size_t count = s.size();
  for (size_t k = 0; k < count; ++k)
    s[k]->setVerbose(v);

  crossModelSolver->setVerbose(v);
  SATSolver::setVerbose(v);
}

// Adds a clause (global literals) to the partition chosen by the decomposition
// strategy, translating each literal into that partition's numbering.
// Returns whatever the partition solver's addClause returns.
bool DeSAT_old::addClause(const std::vector<signed> &literals)
{
  const unsigned part = d->where(literals);
  // print("Next is going to %d\n", part);

  temp.resize(literals.size());

  for (size_t k = 0; k < literals.size(); ++k)
  {
    const signed lit = literals[k];

    // First time this variable reaches the partition: allocate a local one.
    if (!translation.hasVariable(lit, part))
      translation.insert(part, lit, s[part]->addVar());

    temp[k] = translation.to(lit, part);
  }

  return s[part]->addClause(temp);
}

// Not implemented. Throws std::runtime_error (derived from std::exception).
// The std::exception(const char*) constructor used before is an MSVC extension.
bool DeSAT_old::addUnit(signed l)
{
  throw std::runtime_error("NYI: addUnit");
}

// Raises the largest global variable index to n and forwards it to the
// decomposition, the translation table and the cross-model solver.
// Requests below the current maximum are ignored.
void DeSAT_old::setVariableMax(unsigned n)
{
  if (n >= maxVar)
  {
    maxVar = n;
    d->setVariableMax(n);
    translation.setMaxVariable(n);
    crossModelSolver->setVariableMax(n);
  }
}

// Announces the expected number of clauses n to the decomposition. When n is
// smaller than the partition count, the partition count is reduced to n.
// NOTE(review): only `cores` shrinks here; the solvers, translation languages
// and d->setPartitions() configured in init() keep the original count. Confirm
// the decomposition never routes clauses to a partition index >= cores.
void DeSAT_old::setClauseMax(unsigned n)
{
  const bool fewerClausesThanCores = n < cores;
  if (fewerClausesThanCores)
  {
    std::cout << "Warning: less clauses than cores; reducing the number of cores to " << n << std::endl;
    cores = n;
  }
  d->setClauseMax(n);
}

// Not implemented. Throws std::runtime_error (derived from std::exception);
// the former std::exception(const char*) constructor is MSVC-only.
signed DeSAT_old::addVar(void)
{
  throw std::runtime_error("NYI: adding variables");
}

// Total number of clauses held by all partition solvers.
unsigned DeSAT_old::numClauses(void) const
{
  unsigned total = 0;
  for (size_t k = s.size(); k-- > 0; )
    total += s[k]->numClauses();
  return total;
}

// Returns maxVar + 1 (global variables are indexed 1..maxVar).
unsigned DeSAT_old::numVars(void) const
{
  const unsigned count = maxVar + 1;
  return count;
}

// Solving under assumptions is not implemented. Throws std::runtime_error
// (derived from std::exception) instead of the MSVC-only
// std::exception(const char*) constructor.
bool DeSAT_old::solve(const std::vector<signed> &assumptions)
{
  throw std::runtime_error("NYI: solving under assumptions.");
}

// Decides satisfiability of all clauses added so far.
//
// With a single partition this delegates to that partition's solver. Otherwise
// it alternates solvePartitions() and reconcile() until one of them settles the
// instance: solvePartitions() returning true means UNSAT, reconcile()
// returning true means SAT. Returns true iff the instance is SAT.
bool DeSAT_old::solve(void)
{
  // Explicit casts keep the %u conversions correct whatever the member types are.
  printf("Solving with %u variables, %u clauses and %u partitions.\n",
         (unsigned)maxVar, numClauses(), (unsigned)cores);

  if (cores==1)
    return s[0]->solve();

  showDistribution();

  rounds = 0;
  ModelValue res = M_UNDEF;

  while (res==M_UNDEF)
  {
    // Interactive stepping mode.
    if (verbosity==2)
    {
      print("Press any key to continue...");
      ::getchar();
    }

    rounds++;

    for (unsigned i=0; i<79; i++)
      print("=");
    print("\nRound %d\n", rounds);
    for (unsigned i=0; i<79; i++)
      print("-");
    print("\n");

    if (solvePartitions())
      res = M_FALSE; // UNSAT
    else if (reconcile())
      res = M_TRUE; // SAT
  }

  return res==M_TRUE;
}

// Model extraction is not implemented. Throws std::runtime_error (derived
// from std::exception) instead of the MSVC-only std::exception(const char*).
ModelValue DeSAT_old::get(signed l) const
{
  throw std::runtime_error("NYI: model extraction");
}

// Interpolant extraction is not implemented. Throws std::runtime_error
// (derived from std::exception) instead of the MSVC-only
// std::exception(const char*).
Expression DeSAT_old::getInterpolant(const std::vector<signed> &A)
{
  throw std::runtime_error("NYI: interpolant extraction");
}

// Whole-instance model extraction is not implemented. Throws
// std::runtime_error (derived from std::exception) instead of the MSVC-only
// std::exception(const char*).
Expression DeSAT_old::getModel(void) const
{
  throw std::runtime_error("NYI: model extraction");
}

// Builds, in the shared manager m, the conjunction of literals assigned by
// partition i's current model. Only variables that partition i knows and that
// are not exclusive to it are included; unassigned variables are skipped.
// Returns m.mkTrue() when no literal qualifies.
Expression DeSAT_old::getModel(unsigned i) const
{
  std::vector<Expression> lits;

  const signed lastVar = (signed)maxVar;
  for (signed var = 1; var <= lastVar; ++var)
  {
    if (!translation.hasVariable(var, i) || translation.isExclusive(var, i))
      continue;

    const ModelValue value = s[i]->get(translation.to(var, i));
    if (value == M_TRUE)
      lits.push_back(m.mkLiteral(var));
    else if (value == M_FALSE)
      lits.push_back(m.mkLiteral(-var));
  }

  if (lits.empty())
    return m.mkTrue();

  if (lits.size() == 1)
    return m.mkLiteral(m.getLiteral(lits.front()));

  // Left-nested conjunction: ((l0 & l1) & l2) & ...
  Expression cube = m.mkAnd(lits[0], lits[1]);
  for (std::size_t k = 2; k < lits.size(); ++k)
    cube = m.mkAnd(cube, lits[k]);
  return cube;
}

// Not implemented. Throws std::runtime_error (derived from std::exception)
// instead of the MSVC-only std::exception(const char*).
bool DeSAT_old::addConstraint(const Expression &e)
{
  throw std::runtime_error("NYI: addConstraint");
}

// Not implemented. Throws std::runtime_error (derived from std::exception)
// instead of the MSVC-only std::exception(const char*).
signed DeSAT_old::addExtension(const Expression &e)
{
  throw std::runtime_error("NYI: addExtension");
}

// Prints the clause count of each partition, the share of variables used by
// more than one partition, and a cores x cores table with the percentage of
// variables shared by each pair of partitions. Skipped for more than 16
// partitions.
void DeSAT_old::showDistribution(void) const
{
  /*if (verbosity==0)
    return;*/

  if (cores>16) return;

  printf("Distribution:");
  for (unsigned i=0; i<cores; i++)
    printf(" %u", (unsigned)s[i]->numClauses());
  printf("\n");

  // shared_count[i][j]: number of variables shared by partitions i and j
  // (diagonal stays 0). Sized up front so the table below is valid even when
  // no variables exist.
  std::vector< std::vector< unsigned > > shared_count(cores, std::vector<unsigned>(cores, 0));
  unsigned total = maxVar;

  // Global variables are numbered 1..maxVar inclusive.
  for (unsigned v=1; v<=maxVar; v++)
  {
    bool exclusive=true;

    for (unsigned i=0; i<cores; i++)
    {
      for (unsigned j=0; j<cores; j++)
      {
        if (i==j) continue;

        if (translation.isShared(v, i, j))
        {
          exclusive=false;
          shared_count[i][j]++;
        }
      }
    }

    if (exclusive) total--;
  }

  // Avoid 0/0 when no variables have been declared (all counts are 0 then).
  const double denom = (maxVar > 0) ? (double)maxVar : 1.0;

  printf("Non-exclusive variables: %u (%.2f%%)\n", total, 100.0 * total / denom);

  printf("Shared variables (%%):\n");
  printf("Core:");
  for (unsigned i=0; i<cores; i++)
    printf("  %02u", i);
  printf("\n");

  for (unsigned i=0; i<cores; i++)
  {
    printf("%02u:  ", i);
    for (unsigned j=0; j<cores; j++)
    {
      printf(" % 3.0f", (100.0 * shared_count[i][j]) / denom);
    }
    printf("\n");
  }
}

// Solving phase: runs every partition solver on its current clause set.
//
// Each partition that is satisfiable appends its current model (getModel) to
// models[tid]. Prints one letter per partition (S = SAT, U = UNSAT). If a
// partition is UNSAT and has never produced a model, the whole instance is
// UNSAT.
//
// Returns true when the instance is solved (UNSAT), false otherwise.
bool DeSAT_old::solvePartitions(void)
{
  print("Solving phase.\n");
  // Despite the name, this only turns false on a trivial UNSAT (see below).
  bool all_sat = true;
  // char instead of bool: std::vector<bool> packs bits, so concurrent writes
  // to different elements would race. Distinct chars are distinct memory
  // locations.
  std::vector<char> results(cores, 0);

  clock_t before = clock();

  if (sequential)
  {
    for (int tid=0; tid<(signed)cores; tid++)
    {
      bool r = s[tid]->solve();
      if (r)
        models[tid].push_back(getModel(tid));
      results[tid] = r ? 1 : 0;
    }
  }
  else
  {
    // `parallel for`, not a bare `omp for`. The caller solve() opens no
    // parallel region, so an orphaned `omp for` would bind to a team of one
    // thread and run the partitions serially.
    #pragma omp parallel for num_threads(cores)
    for (int tid=0; tid<(signed)cores; tid++)
    {
      bool r = s[tid]->solve();
      if (r)
      {
        // getModel builds the cube in the shared manager m. Serialize it, as
        // reconcile(unsigned, unsigned) does.
        #pragma omp critical
        {
          models[tid].push_back(getModel(tid));
        }
      }
      results[tid] = r ? 1 : 0;
    }
  }

  partitionsTime += clock()-before;

  print("Satisfaction: ");
  for (unsigned i=0; i<results.size(); i++)
  {
    if (results[i])
    {
      print("S");
      assert(models[i].size()>0);
    }
    else
    {
      print("U");
      if (models[i].size()==0)
        all_sat = false; // trivial UNSAT
    }
  }
  print("\n");

  if (all_sat)
    showModels();

  return !all_sat; // true means solved (UNSAT)
}

// Reconciliation phase; every partition is expected to have at least one model.
//
// If crossModelCheck() succeeds, returns true (SAT). Otherwise:
//  - if an interpolation mode is selected, computes interpolants for every
//    ordered pair of partitions (reconcile(tid, j)) and imports them
//    (importInterpolants(tid));
//  - blocks each partition's most recent model by adding the clause that
//    negates every literal of that model to the partition's solver.
// Returns false in that case.
//
// The parallel branches use `parallel for` over the partition index rather
// than one region indexed by omp_get_thread_num(). num_threads(cores) is only
// a request: with fewer threads, the thread-id scheme silently skipped
// partitions.
bool DeSAT_old::reconcile(void)
{
  print("Reconciliation phase.\n");
  // At this point, everybody has a model.

  bool res = false;

  if (crossModelCheck())
  {
    // save model?
    print("... succeeded!\n");
    res = true;
  }
  else
  {
    print("... failed!\n");

    if (interpolationMode != NONE )
    {
      if (sequential)
      {
        for (unsigned tid=0; tid<cores; tid++)
          for (unsigned j=0; j<cores; j++)
            reconcile(tid, j);
      }
      else
      {
        #pragma omp parallel for num_threads(cores)
        for (int tid=0; tid<(signed)cores; tid++)
          for (unsigned j=0; j<cores; j++)
            reconcile(tid, j);
      }

      if (sequential)
      {
        for (unsigned tid=0; tid<cores; tid++)
          importInterpolants(tid);
      }
      else
      {
        #pragma omp parallel for num_threads(cores)
        for (int tid=0; tid<(signed)cores; tid++)
          importInterpolants(tid);
      }
    }

    // Block the latest model of every partition: add NOT(cube) as a clause.
    if (sequential)
    {
      for (unsigned tid=0; tid<cores; tid++)
      {
        if (models[tid].size()>0)
        {
          Expression c = models[tid].back();
          std::vector<signed> lits;
          std::vector<Expression> e_temp;
          assert(m.isCube(c));

          m.getLiterals(c, lits);
          e_temp.resize(lits.size());
          for (unsigned i=0; i<lits.size(); i++)
            e_temp[i] = ms[tid]->mkLiteral(- translation.to(lits[i], tid));

          CExpression t = ms[tid]->mkOr(e_temp);

          ((MiniSAT_1p*)s[tid])->clearNewClauses();
          s[tid]->addConstraint(t);
          ((MiniSAT_1p*)s[tid])->addNewClauses();
        }
      }
    }
    else
    {
      #pragma omp parallel for num_threads(cores)
      for (int tid=0; tid<(signed)cores; tid++)
      {
        if (models[tid].size()>0)
        {
          Expression c = models[tid].back();
          std::vector<signed> lits;
          std::vector<Expression> e_temp;

          // The model cube lives in the shared manager m; access it serially.
          #pragma omp critical
          {
            assert(m.isCube(c));
            m.getLiterals(c, lits);
          }
          e_temp.resize(lits.size());
          for (unsigned i=0; i<lits.size(); i++)
            e_temp[i] = ms[tid]->mkLiteral(- translation.to(lits[i], tid));

          CExpression t = ms[tid]->mkOr(e_temp);

          ((MiniSAT_1p*)s[tid])->clearNewClauses();
          s[tid]->addConstraint(t);
          ((MiniSAT_1p*)s[tid])->addNewClauses();
        }
      }
    }
  }

  return res; // true means solved (SAT)
}

// Checks whether the models collected per partition can be combined into one
// consistent assignment, using crossModelSolver.
//
// Steps:
//  1. Moves all pending models from new_models[i] into models[i].
//  2. For every partition i, extends crossModelSolver once per model not yet
//     encoded. model_vars[i][j] is the variable returned by addExtension for
//     either model j alone (j==0), or (model_vars[i][j-1] OR model j). It
//     therefore presumably stands for "one of models 0..j of partition i" —
//     confirm against addExtension's semantics.
//  3. Solves crossModelSolver assuming the latest model_vars[i] of every
//     partition (collected in the member vector temp, which is overwritten).
//
// Returns the result of crossModelSolver->solve(temp). Elapsed clock() time is
// added to globalTime.
bool DeSAT_old::crossModelCheck(void)
{
  print ("Cross-Model Check ");

  clock_t before = clock();
  temp.clear();

  // Merge models discovered during the previous reconcile phase.
  for (unsigned i=0; i<cores; i++)
  {
    for (unsigned j=0; j<new_models[i].size(); j++)
      models[i].push_back(new_models[i][j]);
    new_models[i].clear();
  }

  for (unsigned i=0; i<cores; i++)
  {
    assert(models[i].size() > 0);

    // Encode only the models added since the last call.
    while (model_vars[i].size() < models[i].size())
    {
      size_t j = model_vars[i].size();
      CExpression model = models[i][j];
      // 0 doubles as "no previous disjunction variable".
      unsigned oldvar = (j==0) ? 0 : model_vars[i][j-1];

      ((MiniSAT_1p*)crossModelSolver)->clearNewClauses();

      // NOTE(review): a true model (getModel returns m.mkTrue() when the
      // partition has no shared assigned variables) re-records oldvar
      // instead of making the disjunction true. When j==0 this stores 0,
      // which is later passed to solve() as an assumption literal. Confirm
      // this cannot happen or handle it.
      if (m.isTrue(model))
      {
        model_vars[i].push_back(oldvar);
      }
      else if (oldvar!=0)
      {
        // Chain the new model onto the previous disjunction: old OR new.
        signed e=crossModelSolver->addExtension(model);
        Expression old = m.mkLiteral(oldvar);
        Expression nm = m.mkLiteral(e);
        CExpression t = m.mkOr(old, nm);
        model_vars[i].push_back(crossModelSolver->addExtension(t));
      }
      else
        model_vars[i].push_back(crossModelSolver->addExtension(model));

      ((MiniSAT_1p*)crossModelSolver)->addNewClauses();
    }

    assert(model_vars[i].size()>0);
    temp.push_back(model_vars[i].back());
  }

  bool r = crossModelSolver->solve(temp);

  globalTime += clock() - before;

  return r;
}

bool DeSAT_old::reconcile(unsigned x, unsigned y)
{
  if (x==y) return true;
  if (models[y].size()==0) return true;  

  //bool models_agree=true;
  //for (unsigned v=1; v<=maxVar; v++)
  //  if (translation.isShared(v, x, y) && 
  //      s[x]->get(translation.to(v, x)) != s[y]->get(translation.to(v, y)))
  //      models_agree = false;
  //
  //if (models_agree)
  //{
  //  // print("%d/%d: Models agree.\n", x, y);
  //  return true;
  //}
  //else
  {
    print("%d/%d: Finding interpolant for %d -> -alpha(%d)...\n", x, y, x, y);
  }
    
  /*print(" SV:");
  for (unsigned v=1; v<=maxVar; v++)
  {
    if (!translation.hasVariable(v, x))
      print(" ");
    else 
    { 
      if (translation.isShared(v, x, y))
        print("+");
      else
        print("-");
    }
  }
  print("\n");*/
      
  Expression last = models[y].back();
  if (m.isTrue(last)) return true;

  clock_t before = clock();

  assumptions[x].clear(); // y's assignment in terms of x's variables.

  if (m.isAnd(last))
  {
    std::vector<CExpression> stack;     
    stack.push_back(last);
          
    while (!stack.empty())      
    {
      CExpression q = stack.back();
      stack.pop_back();
      assert(m.isAnd(q));

      for (unsigned i=0; i<m.nChildren(q); i++)
      {
        CExpression child = m.getChild(q, i);
        if(m.isLiteral(child))
        {
          if (translation.isShared(m.getLiteral(child), x, y))
            assumptions[x].push_back(translation.to(m.getLiteral(child), x));
        }
        else
          stack.push_back(child);
      }
    }
  }
  else if (m.isLiteral(last))
  {
    if (translation.isShared(m.getLiteral(last), x, y))
        assumptions[x].push_back(translation.to(m.getLiteral(last), x));
  }
  else
    throw std::exception("Unexpected expression type.");

  //print("ASM:", x);
  //for (unsigned i=0; i<assumptions[x].size(); i++)
  //  print(" %d", translation.from(assumptions[x][i], x));
  //print("\n");
    
  Interpolator *i;
  switch(interpolationMode)
  {
  case NONE: throw std::exception("No interpolator selected"); break;
  case MCMILLAN: i = new InterpolatorM(*ms[x], x, y, translation); break;
  case INVERSE_MCMILLAN: i = new InterpolatorIM(*ms[x], x, y, translation); break;
  case PUDLAK: i = new InterpolatorP(*ms[x], x, y, translation); break;
  default: 
    throw std::exception("Unknown interpolator selected"); break;
  }

  s[x]->setInterpolator(i);
  Expression itp = s[x]->getInterpolant(assumptions[x]);  
  s[x]->setInterpolator(NULL);
  delete i;
 
  // itp is in terms of y's variables.

  if (ms[x]->isTrue(itp))
  {
    #pragma omp critical
    {
      Expression m = getModel(x);
      new_models[x].push_back(m);
    }
  }
  else if (verbosity>0)
  {
    #pragma omp critical
    {
      std::string t = m.toString(itp);
      print("ITP: %s\n", t.c_str());
    }
  }
    
  #pragma omp critical
  {
    interpolants[x][y] = itp;
  }

  importTime += clock() - before;

  return false;
}

// Adds to partition x every pending interpolant interpolants[i][x] produced by
// another partition i, then resets that slot to nil.
//
// An interpolant lives in ms[i]. It is copied into ms[x] with duplicate() and
// added to s[x] as a constraint; trivially true interpolants are not imported.
// Returns true if at least one interpolant was added.
//
// Thread safety: reconcile() calls this concurrently for different x. While
// this call reads ms[i], partition i may be growing ms[i] with its own imports,
// so each per-interpolant step runs under the named critical section
// `desat_import`. A named section is used so that the unnamed critical
// sections nested inside it take a different lock.
bool DeSAT_old::importInterpolants(unsigned x)
{
  bool did_something = false;

  clock_t before = clock();

  for (unsigned i=0; i<cores; i++)
  {
    if (i==x) continue;

    #pragma omp critical(desat_import)
    {
      Expression itp = interpolants[i][x];
      if (!ms[i]->isNil(itp))
      {
        if (!ms[i]->isTrue(itp))
        {
          if (verbosity>0)
          {
            #pragma omp critical
            {
              std::string t = ms[i]->toString(itp);
              print("#%02d: Importing %s\n", x, t.c_str());
            }
          }
          interpolants_imported++;
          CExpression q = ms[x]->duplicate(itp, *ms[i]);
          ((MiniSAT_1p*)s[x])->clearNewClauses();
          s[x]->addConstraint(q);
          ((MiniSAT_1p*)s[x])->addNewClauses();
          did_something = true;
        }

        #pragma omp critical
        {
          interpolants[i][x]=m.mkNil();
        }
      }
    }
  }

  // Several threads may accumulate into importTime at the same time.
  clock_t elapsed = clock() - before;
  #pragma omp atomic
  importTime += elapsed;

  return did_something;
}

// When verbose, prints every partition's collected models as one disjunction
// per partition.
void DeSAT_old::showModels(void) const
{
  if (verbosity == 0)
    return;

  print("Models:\n");
  for (unsigned part = 0; part < cores; ++part)
  {
    print("#%02d: ", part);
    print("(OR \n");

    const size_t count = models[part].size();
    for (size_t k = 0; k < count; ++k)
    {
      // Rendering goes through the shared manager m.
      #pragma omp critical
      {
        const std::string text = m.toString(models[part][k]);
        print("%s\n", text.c_str());
      }
    }

    print(")\n");
  }
}