// Copyright (C) 2011 Microsoft Research
// CM Wintersteiger, 2011

#ifdef _DEBUG
  #define _CRTDBG_MAP_ALLOC
  #define _CRTDBG_MAP_ALLOC_NEW
  #include <stdlib.h>
  #include <crtdbg.h>
#endif

#include <stdlib.h>

#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>
#include <stdexcept>

#include <omp.h>

#include "minisat1p.h"
#include "decomposition_batch.h"
#include "decomposition_cycle.h"
#include "decomposition_rnd.h"
#include "decomposition_variables.h"
#include "interpolator_m.h"
#include "interpolator_p.h"
#include "interpolator_im.h"

#include "desat.h"

#include <Windows.h>

#include <psapi.h>
#pragma comment(lib, "psapi.lib") //added


// Builds a DeSAT solver that splits the problem into `partitions` parts and
// uses one worker per processor reported by OpenMP (omp_get_num_procs()).
// All counters and timers start at zero; the real setup happens in init().
DeSAT::DeSAT(ExpressionManager &m, unsigned partitions)
  : SATSolver(m)
  , globalTime(0)
  , partitionsTime(0)
  , importTime(0)
  , lastIterationTime(0)
  , maxVar(0)
  , interpolants_imported(0)
  , rounds(0)
  , d(NULL)
  , globalSolver(NULL)
  , early_stop(false)
  , solutions_imported(0)
  , all_sat_found(0)
{
  const unsigned cores = omp_get_num_procs();
  init(partitions, cores);
}

// Builds a DeSAT solver that splits the problem into `partitions` parts and
// runs on `cs` cores. Initializes exactly the same members as the
// (m, partitions) constructor; `rounds` was previously left uninitialized here.
DeSAT::DeSAT(ExpressionManager &m, unsigned partitions, unsigned cs) :
  SATSolver(m),
  globalTime(0),
  partitionsTime(0),
  importTime(0),
  lastIterationTime(0),
  maxVar(0),
  interpolants_imported(0),
  rounds(0),
  d(NULL),
  globalSolver(NULL),
  early_stop(false),
  solutions_imported(0),
  all_sat_found(0)
{  
  init(partitions, cs);
}

void DeSAT::init(unsigned ps, unsigned cs)
{
  n_partitions=ps;
  n_cores=cs;    

  if (n_partitions==1)
    globalSolver = new MiniSAT_1p(m, false);      
  else
  {    
    d = new BatchDecomposition();
    //d = new CycleDecomposition();
    //d = new RandomDecompositon();
    //d = new VariableDecomposition();
    d->setPartitions(n_partitions);
    globalSolver = new MiniSAT_1p(m, false, MiniSAT_1p::LIFTING);    
    assumptions.resize(n_partitions);
    interpolants.resize(n_partitions, m.mkNil());    
    for (unsigned i=0; i<n_partitions; i++)
      partitions.push_back(new Partition(*(new ExpressionManager()), sharedVariables, i, verbosity));
  }
  omp_set_num_threads(n_cores);  
}

// Releases everything allocated by init(): each Partition together with the
// ExpressionManager init() allocated for it, the decomposition and the global
// solver.
DeSAT::~DeSAT(void)
{
  for (unsigned i=0; i<partitions.size(); i++)
  {
    // The Partition holds a reference to its manager, so destroy the
    // Partition first; freeing the manager first would let the Partition's
    // destructor run against already-freed memory.
    const ExpressionManager *pem = &partitions[i]->em();
    delete partitions[i];
    delete pem;
  }

  // delete on a null pointer is a no-op.
  delete d;
  delete globalSolver;
}

void DeSAT::setVerbose(int v)
{
  for (unsigned i=0; i<partitions.size(); i++)
    partitions[i]->setVerbose(v);
  if (globalSolver) globalSolver->setVerbose(v);
  SATSolver::setVerbose(v);
}

// Records the interpolation mode and, when the problem is actually
// decomposed (more than one partition), forwards it to every partition.
void DeSAT::setInterpolator(InterpolationMode i)
{
  interpolationMode = i;

  if (n_partitions > 1)
  {
    for (unsigned p = 0; p < n_partitions; ++p)
      partitions[p]->setInterpolationMode(i);
  }
}

// Adds a clause. Without decomposition it goes straight to the global solver;
// otherwise the decomposition decides which partition receives it.
// Returns whatever the receiving solver/partition's addClause returns.
bool DeSAT::addClause(const std::vector<signed> &literals)
{
  assert(n_cores!=0);

  if (n_partitions==1)
    return globalSolver->addClause(literals);

  // The decomposition signals "no partition" with -1; compare against the
  // explicitly converted sentinel instead of mixing signed and unsigned.
  const unsigned NO_PARTITION = static_cast<unsigned>(-1);
  const unsigned w = d->where(literals);

  if (w == NO_PARTITION)
  {
    assert(false); // This should only happen with a VariableDecomposition!
    return globalSolver->addClause(literals);
  }

  assert(w < partitions.size());
  return partitions[w]->addClause(literals);
}

// Not implemented for DeSAT. Throws std::runtime_error (a std::exception);
// std::exception has no standard const char* constructor.
bool DeSAT::addUnit(signed l)
{
  throw std::runtime_error("NYI: addUnit");
}

// Raises the largest variable index to `n`; smaller values are ignored (the
// bound only grows). Propagates the bound to the global solver and, when
// decomposed, to the decomposition and every partition.
void DeSAT::setVariableMax(unsigned n)
{
  if (n >= maxVar)
  {
    maxVar = n;

    assert(n_cores != 0);

    globalSolver->setVariableMax(n);

    if (n_partitions <= 1)
      return;

    d->setVariableMax(n);
    for (unsigned k = 0; k < partitions.size(); ++k)
      partitions[k]->setVariableMax(n);
  }
}

// Announces the expected number of clauses to the decomposition and to every
// partition. Has no effect without decomposition.
// NOTE: the partition count is deliberately NOT reduced when n is smaller
// than the number of partitions.
void DeSAT::setClauseMax(unsigned n)
{
  if (n_partitions <= 1)
    return;

  d->setClauseMax(n);
  for (unsigned k = 0; k < partitions.size(); ++k)
    partitions[k]->setClauseMax(n);
}

// Not implemented for DeSAT. Throws std::runtime_error (a std::exception);
// std::exception has no standard const char* constructor.
signed DeSAT::addVar(void)
{
  throw std::runtime_error("NYI: adding variables");
}

// Total number of clauses: those held by the global solver plus those held
// by every partition.
unsigned DeSAT::numClauses(void) const
{
  unsigned total = globalSolver->numClauses();
  for (unsigned k = 0; k < partitions.size(); ++k)
    total += partitions[k]->numClauses();
  return total;
}

// Number of variable slots: the largest variable index plus one
// (presumably index 0 is an unused slot — variables are numbered 1..maxVar
// throughout this file).
unsigned DeSAT::numVars(void) const
{
  const unsigned highest = maxVar;
  return highest + 1;
}

// Solving under assumptions is not implemented for DeSAT. Throws
// std::runtime_error (a std::exception); std::exception has no standard
// const char* constructor.
bool DeSAT::solve(const std::vector<signed> &assumptions)
{
  throw std::runtime_error("NYI: solving under assumptions.");
}

// Model extraction is not implemented for DeSAT. Throws std::runtime_error
// (a std::exception); std::exception has no standard const char* constructor.
ModelValue DeSAT::get(signed l) const
{
  throw std::runtime_error("NYI: model extraction");
}

// Not implemented for DeSAT. Throws std::runtime_error (a std::exception);
// std::exception has no standard const char* constructor.
Expression DeSAT::getInterpolant(const std::vector<signed> &A)
{
  throw std::runtime_error("NYI: getInterpolant");
}

// Model extraction is not implemented for DeSAT. Throws std::runtime_error
// (a std::exception); std::exception has no standard const char* constructor.
Expression DeSAT::getModel(void) const
{
  throw std::runtime_error("NYI: model extraction");
}

// Not implemented for DeSAT. Throws std::runtime_error (a std::exception);
// std::exception has no standard const char* constructor.
bool DeSAT::addConstraint(CExpression &e)
{ 
  throw std::runtime_error("NYI: addConstraint");
}

// Not implemented for DeSAT. Throws std::runtime_error (a std::exception);
// std::exception has no standard const char* constructor.
signed DeSAT::addExtension(CExpression &e)
{
  throw std::runtime_error("NYI: addExtension");
}

// Prints diagnostics about the decomposition to stdout: clause counts per
// partition (global solver in brackets), the number/percentage of variables
// shared between at least two partitions, and a pairwise sharing matrix in
// percent of all variables. Skipped for more than 16 partitions.
void DeSAT::showDistribution(void) const
{
  /*if (verbosity==0)
    return;*/

  if (n_partitions>16) return;

  printf("Distribution:");
  for (unsigned i=0; i<partitions.size(); i++)
    printf(" %u", static_cast<unsigned>(partitions[i]->numClauses()));
  printf(" [%u]\n", static_cast<unsigned>(globalSolver->numClauses()));

  // shared_count[i][j]: number of variables shared between partitions i and j.
  // Sized up front so the printing loop below never indexes an empty row
  // (the variable loop does not run at all when maxVar == 0).
  std::vector< std::vector< unsigned > > shared_count(n_partitions, std::vector< unsigned >(n_partitions, 0));
  unsigned total = maxVar;

  for (unsigned v=1; v<=maxVar; v++)
  {
    bool exclusive=true;

    for (unsigned i=0; i<n_partitions; i++)
    {
      for (unsigned j=0; j<n_partitions; j++)
      {
        if (i==j) continue;

        if (sharedVariables.isShared(v, i, j))
        {
          exclusive=false;
          shared_count[i][j]++;
        }
      }
    }

    // Variables not shared by any pair of partitions are not "global".
    if (exclusive) total--;
  }

  // Avoid 0/0 when there are no variables; use floating-point 100.0 so the
  // numerator cannot overflow unsigned arithmetic.
  const double denom = (maxVar > 0) ? static_cast<double>(maxVar) : 1.0;

  printf("Global variables: %u (%.2f%%)\n", total, 100.0 * total / denom);

  printf("Sharing matrix (%%):\n");  
  printf("Partition  ");
  for (unsigned i=0; i<n_partitions; i++)
    printf("  %02u", i);
  printf("\n");

  for (unsigned i=0; i<n_partitions; i++)
  {
    printf("       %02u  ", i);
    for (unsigned j=0; j<n_partitions; j++)
      printf(" % 3.0f", (100.0 * shared_count[i][j]) / denom);
    printf("\n");
  }

  fflush(stdout);
}

bool DeSAT::solve(void)
{ 
  printf("Solving with %d variables, %d clauses and %d partitions on %d cores.\n", maxVar, numClauses(), n_partitions, n_cores);
  clock_t before = clock();

  if (n_partitions==1)
  {
    bool r = globalSolver->solve();
    globalTime += clock()-before;
    return r;
  }

  showDistribution();

  assert(globalSolver);  
  
  sharedVariables.update();

  rounds = 0;
  ModelValue res = M_UNDEF;

  while (res==M_UNDEF)
  {
    if (verbosity==2)
    {
      print("Press any key to continue...");
      ::getchar();
    }

    rounds++;

	//PROCESS_MEMORY_COUNTERS pmc;
	//GetProcessMemoryInfo( GetCurrentProcess(), &pmc, sizeof(pmc));

	//std::cout << " WorkingSetSize: " << pmc.WorkingSetSize / 1048576 << " MB" << std::endl;

    for (unsigned i=0; i<79; i++)
      print("=");
    print("\nRound %d\n", rounds);
    for (unsigned i=0; i<79; i++)
      print("-");
    print("\n");

    if (!solveGlobals())
      res = M_FALSE;
    else
    {
      if (solvePartitions())
      {
		all_sat_found++;
        if (!findDisagreement())
          res = M_TRUE;        
        importInterpolants(); // deletes the interpolant objects
      }
      else if (!importInterpolants())
        res = M_TRUE; // SAT!    
    }
  }

  //  for(signed i = 0; i < 1000; i++)
		//std::cout << "i: " << i << " " << m.toString(i) << std::endl;

	return res==M_TRUE;
}

// Searches for a global assignment compatible with `trail`, which is passed
// to the global solver as assumptions.
// `reasons[k]` parallels `trail[k]`: entries pushed here (flipped literals)
// are marked true, entries pushed by findDisagreement() are marked false.
// On a conflict the trail is popped back to the most recent literal whose
// negation appears in the solver's conflict; if that literal was a
// findDisagreement() choice (reason false) it is flipped and re-pushed with
// reason true, then the solver is retried. Gives up when the trail is empty
// and the last popped entry was already a flip (or nothing was popped).
// Returns true iff a global assignment was found. Updates globalTime and
// lastIterationTime.
bool DeSAT::solveGlobals(void)
{
  clock_t before = clock();
  print("Finding global assignment...\n");
  // trail.size() is a size_t; pass an unsigned with a matching specifier
  // instead of feeding a size_t to %d through varargs.
  print("Trail size: %u\n", static_cast<unsigned>(trail.size()));

  for (unsigned i=0; i<trail.size(); i++)
  {
    print(" %d", trail[i]);
    if (reasons[i]) print("!");
  }
  print("\n");

  bool r=globalSolver->solve(trail);  
  
  while (!r && trail.size()>0)
  { 
    // NOTE(review): the search below looks for -last in the conflict, so the
    // conflict presumably contains negated assumption literals (MiniSat
    // convention) — confirm against MiniSAT_1p::getConflict.
    std::vector<signed> temp;
    ((MiniSAT_1p*)globalSolver)->getConflict(temp);
    
    signed last = 0;
    bool reason = true;
    bool inConflict= false;

    print("GLOBAL CONFLICT:");
    if (temp.size()==0)
      print (" EMPTY");
    else
      for (unsigned i=0; i<temp.size(); i++)
        print(" %d", temp[i]);
    print("\n");

    print("TRAIL:");
    for (unsigned i=0; i<trail.size(); i++)
    {
      print(" %d", trail[i]);
      if (reasons[i]) print("!");
    }
    print("\n");

    // Pop the trail until we remove a literal that takes part in the conflict.
    inConflict = std::find(temp.begin(), temp.end(), -last)!=temp.end();    
    while (trail.size()>0 && !inConflict)
    {
      last = trail.back();
      trail.pop_back();
      reason = reasons.back();
      reasons.pop_back();      
      inConflict = std::find(temp.begin(), temp.end(), -last)!=temp.end();
    }

    // A conflicting decision literal is flipped; an already-flipped literal
    // is simply dropped.
    if (inConflict && !reason)
    {      
      trail.push_back(-last);
      reasons.push_back(true);
    }

    if (trail.size()==0 && reason)
      r = false;
    else
    {      
      print("TRAIL:");
      for (unsigned i=0; i<trail.size(); i++)
      {
        print(" %d", trail[i]);
        if (reasons[i]) print("!");
      }
      print("\n");
      r=globalSolver->solve(trail);
    }
  }
  globalTime += clock() - before;
  lastIterationTime = clock() - before;

  if (!r)
  {  
    print("No global assignment.\n");
  }
  else if (verbosity>1)
  {
    print("Global Assignment:");
    for (int v=1; v<=(signed)maxVar; v++)
      if (sharedVariables.isShared(v))
      {        
        ModelValue mv = globalSolver->get(v);
        if (mv != M_UNDEF)
         print(" %d", mv==M_TRUE ? v : -v );
      }
    print("\n");
  }

  return r;
}

// Tries to extend the current global assignment in every partition.
// The global solver's values of all shared variables become `assumptions`.
// For each partition an interpolant is requested under those assumptions. A
// result that is not "true" in the partition's own ExpressionManager is
// copied into the global manager `m` and stored in interpolants[pid]. A
// "true" result leaves interpolants[pid] as nil (m.mkNil()).
// Runs sequentially when n_cores == 1 (optionally stopping at the first
// failing partition if early_stop is set); otherwise the interpolants are
// computed in an OpenMP parallel loop and copied into `m` afterwards,
// sequentially.
// Returns true iff every partition produced a "true" interpolant. Errors
// raised while copying in the parallel path are rethrown at the end.
bool DeSAT::solvePartitions(void)
{
  clock_t before = clock();
  print("Extending assignment...\n");
  bool all_sat = true;

  assumptions.clear();
  for (signed v=1; v<=(signed)maxVar; v++)
  {
    if (sharedVariables.isShared(v))
    {
      ModelValue mv = globalSolver->get(v);
      if (mv==M_UNDEF) continue;
      assumptions.push_back( mv==M_TRUE ? v : -v );
    }
  }

  if (verbosity>0)
  {
    print("Assumptions: ");
    for (unsigned i=0; i<assumptions.size(); i++)  
      print(" %d", assumptions[i]);
    print("\n");
  }

  have_error = false;
  have_bad_alloc = false;

  if (n_cores == 1)
  {
    for (int pid=0; pid<(signed)n_partitions; pid++)
    {
      interpolants[pid] = m.mkNil();
      Expression t = partitions[pid]->getInterpolant(assumptions);

      if (!partitions[pid]->em().isTrue(t))
      {
        interpolants[pid] = m.duplicate(t, partitions[pid]->em());
        all_sat = false;
        if (early_stop) break;
      }
    }
  }
  else
  {
    // Parallel phase: each iteration touches only its own partition and its
    // own slot of `interpolants`; the shared global manager `m` is not used
    // here.
    #pragma omp parallel for default(shared) num_threads(n_cores)
    for (int pid=0; pid<(signed)n_partitions; pid++)
    {
      // Pin the thread to one core. Shift in DWORD_PTR width: an int shift
      // overflows for thread numbers >= 31.
      DWORD_PTR mask = static_cast<DWORD_PTR>(1) << omp_get_thread_num();
      SetThreadAffinityMask( GetCurrentThread(), mask );

      // This still lives in partitions[pid]->em() until copied below.
      interpolants[pid] = partitions[pid]->getInterpolant(assumptions);
    }

    // Sequential phase: move results into the global manager `m`.
    for (int i=0; i<(signed)n_partitions; i++)
    {
      Expression t = interpolants[i];

      if (partitions[i]->em().isTrue(t))
      {
        // Same convention as the sequential path: a satisfied partition
        // leaves nil, so no handle from partitions[i]->em() is ever left in
        // `interpolants` for importInterpolants() to read through `m`.
        interpolants[i] = m.mkNil();
        continue;
      }

      // Errors are recorded and rethrown after partitionsTime is updated.
      try
      {
        interpolants[i] = m.duplicate(t, partitions[i]->em());
        all_sat = false;
      }
      catch (std::bad_alloc &e)
      {
        have_bad_alloc = true;
        ba_exception = e;
      }
      catch (std::exception &e)
      {
        have_error = true;
        exception = e;
      }
      catch (...)
      {
        std::cout << "UNRECOVERABLE ERROR" << std::endl;
      }
    }
  }

  partitionsTime += clock() - before;

  if (have_error)
    throw exception;
  else if (have_bad_alloc)
    throw ba_exception;

  return all_sat;
}

bool DeSAT::importInterpolants(void)
{
  clock_t before = clock();
  bool did_something = false;

  for (unsigned i=0; i<n_partitions; i++)
  {
    CExpression itp = interpolants[i];
    if (m.isNil(itp)){
		solutions_imported++;
		continue;
	}

    if (!m.isTrue(itp))
    {
      if (verbosity>0)
      {
        std::string t = m.toString(itp);
        print("Global interpolant import: %s\n", t.c_str());

			//std::string filename = "filename";
			//std::stringstream convert; // stringstream used for the conversion
			//convert << rounds;
			//convert << i;
			//filename += convert.str();
			//m.toDot(itp,filename);

      }      
      ((MiniSAT_1p*)globalSolver)->clearNewClauses();
      globalSolver->addConstraint(itp);      
      ((MiniSAT_1p*)globalSolver)->addNewClauses();
      interpolants_imported++;
      did_something=true;
    } 

    interpolants[i] = m.mkNil();
  }

  clock_t after = clock();
  importTime += after-before;

  return did_something;
}

// Looks for shared variables to which two partitions' models assign
// different (defined) values. For each such variable the value chosen by the
// lower-numbered partition is appended to `trail` (reason false) unless it is
// already there. Throws std::runtime_error if the trail already contains the
// opposite literal. Returns true iff at least one literal was appended.
bool DeSAT::findDisagreement(void)
{
  std::vector<signed> new_trail;

  for (unsigned i=0; i<n_partitions; i++)
  {
    for (signed v=1; v<=(signed)maxVar; v++)
    {
      if (!sharedVariables.occurs(v, i)) continue;

      ModelValue mvi = partitions[i]->get(v);
      if (mvi == M_UNDEF) continue;

      for (unsigned j=i+1; j<n_partitions; j++)
      {
        if (!sharedVariables.occurs(v, j)) continue;

        ModelValue mvj = partitions[j]->get(v);
        if (mvj == M_UNDEF || mvj == mvi) continue;

        signed x = (mvi==M_TRUE) ? v : -v;

        if (std::find(trail.begin(), trail.end(), -x) != trail.end())
          throw std::runtime_error("Global assignment not respected");

        if (std::find(trail.begin(), trail.end(), x) == trail.end() &&
            std::find(new_trail.begin(), new_trail.end(), x) == new_trail.end())
          new_trail.push_back(x);
      }
    }
  }

  for (unsigned i=0; i<new_trail.size(); i++)
  {
    trail.push_back(new_trail[i]);
    reasons.push_back(false);
  }

  return !new_trail.empty();
}
