#define TI_NO_SRCPOS
#include "ti_config.h"
#include "tally-memory.h"
#include "ti-memory.h"
#define STOPTIFU_INTERNAL_DEFINE
#include "stoptifu-runtime.h"

/* Verbosity threshold consulted by the DEBUGn() macros below and by
   ivseti_compute_alpha(), which sets it from its dbg argument. */
static int debug_level;

/* Change this "#if 0" to "#if 1" to force all DEBUGn() output on at
   compile time, regardless of debug_level. */
#if 0
#define DEBUG_ONE 1
#define DEBUG_TWO 1
#define DEBUG_THREE 1
#define DEBUG_FOUR 1
#endif

/* DEBUGn(x) executes x when debug_level >= n at run time, or always when
   the matching DEBUG_<N> macro is defined to a nonzero value.  An
   undefined DEBUG_<N> evaluates to 0 in #if, selecting the run-time
   check.  The do { } while (0) wrapper makes each use a single statement. */
#if DEBUG_ONE
#define DEBUG1(x) x
#else
#define DEBUG1(x) do { if (debug_level >= 1) x; } while (0)
#endif

#if DEBUG_TWO
#define DEBUG2(x) x
#else
#define DEBUG2(x) do { if (debug_level >= 2) x; } while (0)
#endif

#if DEBUG_THREE
#define DEBUG3(x) x
#else
#define DEBUG3(x) do { if (debug_level >= 3) x; } while (0)
#endif

#if DEBUG_FOUR
#define DEBUG4(x) x
#else
#define DEBUG4(x) do { if (debug_level >= 4) x; } while (0)
#endif

/* findem2() and findem3() call each other, so findem2() must be declared
   before findem3() is defined. */
static ivseti *findem2(int n, int m, const int * const *v,
		       const int * const *inv, int k, int **r);

/* Return a newly ALLOC'd string of the form "[v0, v1, ..., vn-1]" that
   describes the n ints at v; used for debug and diagnostic output.
   n <= 0 yields "[]" and v is not read.  The caller owns the result and
   may release it with FREE(). */
char *str(const int *v, int n)
{
  /* Each element needs at most 13 chars (", " plus 11 for INT_MIN);
     20 per element is a generous bound, and the extra 3 cover "[", "]"
     and the terminating NUL (also making n <= 0 safe). */
  size_t cap = 20 * (size_t) (n > 0 ? n : 0) + 3;
  char *s = ALLOC(char, cap);
  size_t pos = 0;
  int i;

  s[pos++] = '[';
  /* Track the write position instead of calling strlen() each time,
     which made the original loop quadratic in n. */
  for (i = 0; i < n; i++)
    pos += (size_t) snprintf(s + pos, cap - pos, "%s%d",
			     i == 0 ? "" : ", ", v[i]);
  snprintf(s + pos, cap - pos, "]");
  return s;
}

/* Translate the rectangle r (r[0] = low corner, r[1] = high corner) in
   n-space by y times the vector p, updating both corners in place. */
static void moverect(const int *p, int y, int **r, int n)
{
  int side;
  for (side = 0; side <= 1; ++side) {
    int *pt = r[side];
    int d;
    for (d = 0; d < n; ++d)
      pt[d] += y * p[d];
  }
}

/* Return the set of integers p with lo <= k * p <= hi.  For k == 0 the
   answer is every integer if 0 lies in [lo, hi], and no integer otherwise.
   A negative k flips the inequalities, so the roles of lo and hi swap. */
static iseti findem_interval(int k, int lo, int hi)
{
  int first, last;

  if (k == 0) {
    int zero_fits = lo <= 0 && hi >= 0;
    return zero_fits ? iseti_universe() : iseti_emptyset();
  }

  DEBUG3(printf("findem_interval(%d, %d, %d): ", k, lo, hi)); 
  if (k < 0) {
    first = divide_int_rounding_to_plus_inf(hi, k);
    last = divide_int_rounding_to_minus_inf(lo, k);
  } else {
    first = divide_int_rounding_to_plus_inf(lo, k);
    last = divide_int_rounding_to_minus_inf(hi, k);
  }
  DEBUG3(printf("%d - %d\n", first, last)); 
  return iseti_single_interval(make_iinterval(first, last));
}

/* Find all integers p s.t. lo <= mult * p <= hi and store the smallest
   and largest such p in *plo and *phi.  This mirrors findem_interval()
   but reports plain bounds instead of a set.
   - mult == 0 and 0 in [lo, hi]: every p qualifies (an infinite set),
     which cannot be expressed as bounds, so print a diagnostic and abort().
   - mult == 0 otherwise: no p qualifies; the empty interval [1, 0] is
     stored.
   - mult != 0: the bounds come from dividing lo and hi by mult, rounding
     inward; a negative mult swaps which end bounds which. */
void compute_alpha_1D(int mult, int lo, int hi, int *plo, int *phi)
{
  if (mult == 0) {
    if (lo <= 0 && 0 <= hi) {
      fprintf(stderr, "compute_alpha_1D: infinite solution set "
	      "(0 * p in [%d, %d] for every p)\n", lo, hi);
      abort();
    }
    *plo = 1;
    *phi = 0;
    return;
  }

  DEBUG3(printf("compute_alpha_1D(%d, %d, %d): ", mult, lo, hi)); 
  if (mult > 0) {
    lo = divide_int_rounding_to_plus_inf(lo, mult);
    hi = divide_int_rounding_to_minus_inf(hi, mult);
  } else {
    /* Dividing by a negative number reverses the inequalities. */
    int t = divide_int_rounding_to_plus_inf(hi, mult);
    hi = divide_int_rounding_to_minus_inf(lo, mult);
    lo = t;
  }
  DEBUG3(printf("%d - %d\n", lo, hi)); 
  *plo = lo;
  *phi = hi;
}

/* Copy one corner of the n-dimensional rectangle r into q.  Bit d of
   corner chooses, for coordinate d, the high side r[1] (bit set) or the
   low side r[0] (bit clear). */
static void pickcorner(int **r, int corner, int *q, int n)
{
  int d;
  for (d = 0; d < n; d++) {
    int side = (corner >> d) & 1;
    q[d] = r[side][d];
  }
}

/* Compute b = A * x for an n-by-n matrix A given as n row pointers.
   Each b[row] is written after that row's dot product is complete.
   NOTE(review): nothing in this file calls mult(); only tmult() is used. */
static void mult(int **A, int *x, int *b, int n)
{
  int row;
  for (row = 0; row < n; row++) {
    const int *a = A[row];
    int acc = 0, col;
    for (col = 0; col < n; col++)
      acc += a[col] * x[col];
    b[row] = acc;
  }
}

/* Compute b = transpose(A) * x for an n-by-n matrix A given as n row
   pointers, so b[col] is the dot product of column col of A with x. */
static void tmult(const int * const *A, const int *x, int *b, int n)
{
  int col;
  for (col = 0; col < n; col++) {
    int acc = 0, row;
    for (row = 0; row < n; row++)
      acc += A[row][col] * x[row];
    b[col] = acc;
  }
}

/* Enumerate integer coefficient lists (x0, ..., x{m-1}) such that
   x0 * v[0] + ... + x{m-1} * v[m-1] lies in the rectangle r
   (r[0] = low corner, r[1] = high corner) in n-space.  range holds m
   intervals; xi is tried over range[i] for every vector except the last,
   whose coefficient is solved directly by findem2().  r is translated
   while the search runs but is restored before each iteration ends. */
static ivseti *findem3(int n, int m, const int * const *v,
		       int **r, iinterval *range)
{
  const int *head;
  ivseti *acc;
  int coef, last;

  DEBUG2(printf("findem3 n=%d m=%d v[0]=%s r=[%s : %s]\n", n, m,
		str(v[0], n), str(r[0], n), str(r[1], n)));
  if (m == 1)
    return findem2(n, 1, v, NULL, 0, r);

  head = v[0];
  coef = range[0].lo;
  last = range[0].hi;
  acc = ivseti_emptyset();
  for (; coef <= last; coef++) {
    ivseti *tail;
    DEBUG3(printf("%*sChecking %d\n", 10 - 2 * m, "", coef));
    /* Fix the first coefficient by shifting the target rectangle, then
       solve for the remaining m - 1 vectors. */
    moverect(head, -coef, r, n);
    tail = findem3(n, m - 1, v + 1, r, range + 1);
    if (!ivseti_is_empty(tail))
      acc = ivseti_union(acc, ivseti_cons(coef, tail));
    moverect(head, coef, r, n);
  }
  return acc;
}

/* Find every list of integer coefficients x such that
   x[0] * v[0] + ... + x[m-1] * v[m-1] lies in the rectangle r
   (r[0] = low corner, r[1] = high corner).  n is the number of
   dimensions and v holds m vectors in n-space; m should be n or 1.
   For m == 1 the answer is the intersection of the per-axis solution
   sets of r[0][i] <= x * v[0][i] <= r[1][i].  For m == n, inv and k must
   satisfy inv * v == k * I.  Each corner of r is mapped through the
   transpose of inv to bound k times each coefficient, and findem3() then
   enumerates within those bounds.  inv and k are unused when m == 1. */
static ivseti *findem2(int n, int m, const int * const *v,
		       const int * const *inv, int k, int **r)
{
  if (m == 1) {
    int i;
    iseti result = iseti_emptyset();
    DEBUG1(printf("findem2 n=%d m=%d v[0]=%s r=[%s : %s]\n", n, m,
		  str(v[0], n), str(r[0], n), str(r[1], n)));
    /* Intersect all 1d solutions */
    for (i = 0; i < n; i++) {
      iseti o = findem_interval(v[0][i], r[0][i], r[1][i]);
      result = (i == 0) ? o : iseti_intersection(result, o);
    }
    return ivseti_flat(result);
  } else {
    int i, corners = 1 << n, corner,
      *c = ALLOC(int, n), *q = ALLOC(int, n);
    iinterval *range = ALLOC(iinterval, n);
    ivseti *result;

    /* Start each range as [0, -1] (presumably treated as empty by
       include_in_iinterval()) and widen it to cover every corner. */
    for (i = 0; i < n; i++) {
      range[i].lo = 0;
      range[i].hi = -1;
    }
    /* The map from point to (scaled) coefficients is linear, so over the
       box r its extremes occur at the 2^n corners. */
    for (corner = 0; corner < corners; corner++) {
      pickcorner(r, corner, q, n);
      tmult(inv, q, c, n);
      DEBUG2(printf("corner %d (%s) maps to %s\n", corner, str(q, n), str(c, n)));
      for (i = 0; i < n; i++)
	include_in_iinterval(&(range[i]), c[i]);
    }
    /* q and c are scratch space for the corner scan only. */
    FREE(q);
    FREE(c);
    for (i = 0; i < n; i++) {
      divide_iinterval(&(range[i]), k);
      DEBUG2(printf("intervals for i=%d: %d to %d\n", i, range[i].lo, range[i].hi));
    }
    result = findem3(n, m, v, r, range);
    /* findem3() only reads range while it runs, so release it now
       (previously leaked on every call). */
    FREE(range);
    return result;
  }
}

/* Step b to the lexicographically next integer point of the rectangle
   [lo : hi] (the last coordinate varies fastest) and return true.  If no
   later point exists (every b[i] >= hi[i]), reset b to lo and return
   false.  With n <= 0 there is nothing to advance: return false without
   touching b.  (The previous do-while form indexed b[-1] in that case.) */
TI_INLINE(bump) 
bool bump(int *b, int *lo, int *hi, int n) {
  int i;
  for (i = n - 1; i >= 0; i--) {
    if (++b[i] <= hi[i])
      return true;
    /* This coordinate overflowed: wrap it and carry into the next one. */
    b[i] = lo[i];
  }
  return false;
}

/* Print "<prefix>[r0 : r1] <verb> p + x[0] * v[0] + ... + x[n-1] * v[n-1]
   = z" and a newline to out, releasing every temporary string that str()
   allocates along the way. */
static void print_combination(FILE *out, const char *prefix,
			      const char *verb, int n, const int *p,
			      const int * const *v, int **r,
			      const int *x, const int *z)
{
  int j;
  char *lo_s = str(r[0], n), *hi_s = str(r[1], n), *p_s = str(p, n), *z_s;

  fprintf(out, "%s[%s : %s] %s %s", prefix, lo_s, hi_s, verb, p_s);
  FREE(lo_s);
  FREE(hi_s);
  FREE(p_s);
  for (j = 0; j < n; j++) {
    char *v_s = str(v[j], n);
    fprintf(out, " + %d * %s", x[j], v_s);
    FREE(v_s);
  }
  z_s = str(z, n);
  fprintf(out, " = %s\n", z_s);
  FREE(z_s);
}

/* Verify a result c computed by ivseti_compute_alpha().  For every
   coefficient vector x in c, check that p + x[0]*v[0] + ... lies in the
   rectangle r (r[0] = low corner, r[1] = high corner).  On the first
   failure, describe the bad point on stderr and abort().  If verbose is
   false, nothing is printed unless verification fails. */
static void verify(int n, const int *p, const int * const *v,
		   int **r, ivseti *c, bool verbose)
{
  int i, j, *x, *lo, *hi, *z;

  if (ivseti_is_empty(c))
    return;
  if (verbose) {
    printf("Verifying: c=");
    ivseti_print(c);
    puts("");
  }
  assert(ivseti_dim(c) == n);

  x = ALLOC(int, n);
  lo = ALLOC(int, n);
  hi = ALLOC(int, n);
  z = ALLOC(int, n);
  for (i = 0; i < n; i++) {
    lo[i] = x[i] = ivseti_min(c, i);
    hi[i] = ivseti_max(c, i);
  }
  /* x iterates over the bounding box of c; only members of c are checked. */
  do {
    if (ivseti_contains(c, x)) {
      /* z = p + sum over j of x[j] * v[j] */
      for (i = 0; i < n; i++) {
	z[i] = p[i];
	for (j = 0; j < n; j++)
	  z[i] += x[j] * v[j][i];
      }
      for (i = 0; i < n; i++)
	if (z[i] < r[0][i] || z[i] > r[1][i]) {
	  print_combination(stderr, "ERROR! Rectangle ", "does not contain",
			    n, p, v, r, x, z);
	  abort();
	}
      if (verbose)
	print_combination(stdout, "", "does contain", n, p, v, r, x, z);
    }
  } while (bump(x, lo, hi, n));
  FREE(x);
  FREE(hi);
  FREE(lo);
  FREE(z);
}

/* Find all integer coefficient vectors x such that
   p + x[0]*v[0] + ... + x[n-1]*v[n-1] lies in the rectangle r
   (r[0] = low corner, r[1] = high corner), everything in n-space.
   inv must satisfy inv * v == k * I.  dbg becomes the new debug_level;
   when it is positive the answer is checked with verify() and printed.
   r is shifted during the computation and restored before returning. */
ivseti *ivseti_compute_alpha(int n, const int *p, const int * const *v,
			     const int * const *inv, int k, int **r, int dbg)
{
  ivseti *answer;

  debug_level = dbg;

  /* Shifting r by -p reduces the problem to the p = 0 case; shift back
     afterwards so the caller's rectangle is unchanged. */
  moverect(p, -1, r, n);
  answer = findem2(n, n, v, inv, k, r);
  moverect(p, 1, r, n);

  if (debug_level <= 0)
    return answer;

  verify(n, p, v, r, answer, debug_level > 1);
  {
    char *text = ivseti_to_string(answer);
    printf("ivseti_compute_alpha: result is %s\n", text);
    FREE(text);
  }
  return answer;
}

#if 0
/* The rectN_compute_alpha() functions below are unused; instead, N calls to compute_alpha_1D() are generated. */

/* Based on ivseti_compute_alpha(), above.  Only works for the case
   where the result is known to be a dense 1D rectangle (interval).
   Set debug_level to dbg.
   Solves r[0][0] - p[0] <= v[0][0] * x <= r[1][0] - p[0] with
   compute_alpha_1D(), storing the bounds of x in result.l0 / result.h0.
   NOTE(review): the "#if 1" section always prints the result and re-runs
   ivseti_compute_alpha() as a cross-check, discarding its return value;
   presumably debugging scaffolding to remove before re-enabling. */
rect1 rect1_compute_alpha(int n, const int *p, const int * const *v,
			  const int * const *inv, int k, int **r, int dbg)
{
  rect1 result;
  debug_level = dbg;
  int lo = r[0][0], hi = r[1][0], offset = *p;
  compute_alpha_1D(v[0][0], lo - offset, hi - offset, &result.l0, &result.h0);
#if 1
  {
    char *s = rect1_to_string(result);
    printf("rect1_compute_alpha: result is %s\n", s);
    FREE(s);
  }
  ivseti_compute_alpha(n, p, v, inv, k, r, (dbg > 1) ? dbg : 1);
#endif
  return result;
}

/* Based on ivseti_compute_alpha(), above.  Only works for the case
   where the result is known to be a dense 2D rectangle. 
   Set debug_level to dbg.
   Shifts r by -p, lets rect2_helper() (not defined in this file) compute
   the two coordinate ranges, then restores r.  When v[0][0] == 0 the
   helper is given v[0][1] and v[1][0] and writes the (l1, h1) pair first;
   presumably v is assumed diagonal or anti-diagonal — confirm.
   NOTE(review): the "#if 1" section always prints the result and re-runs
   ivseti_compute_alpha() as a cross-check, discarding its return value. */
rect2 rect2_compute_alpha(int n, const int *p, const int * const *v,
			  const int * const *inv, int k, int **r, int dbg)
{
  rect2 result;
  debug_level = dbg;
  
  /* Subtract off p from the rectangle, then solve the p = 0 case. */
  moverect(p, -1, r, n);
  if (v[0][0] == 0)
    rect2_helper(v[0][1], v[1][0], r, &result.l1, &result.h1, &result.l0, &result.h0);
  else
    rect2_helper(v[0][0], v[1][1], r, &result.l0, &result.h0, &result.l1, &result.h1);
  /* Restore r to its original value. */
  moverect(p, 1, r, n);

#if 1
  {
    char *s = rect2_to_string(result);
    printf("rect2_compute_alpha: result is %s\n", s);
    FREE(s);
  }
  ivseti_compute_alpha(n, p, v, inv, k, r, (dbg > 1) ? dbg : 1);
#endif

  return result;
}

/* Based on ivseti_compute_alpha(), above.  Only works for the case
   where the result is known to be a dense 3D rectangle.
   Set debug_level to dbg.
   NOTE(review): unfinished stub.  result is never assigned, so it is
   printed via rect3_to_string() and returned uninitialized (undefined
   behavior).  It must be completed before this code is re-enabled. */
rect3 rect3_compute_alpha(int n, const int *p, const int * const *v,
			  const int * const *inv, int k, int **r, int dbg)
{
  rect3 result;
  debug_level = dbg;
#if 1
  {
    char *s = rect3_to_string(result);
    printf("rect3_compute_alpha: result is %s\n", s);
    FREE(s);
  }
  ivseti_compute_alpha(n, p, v, inv, k, r, (dbg > 1) ? dbg : 1);
#endif
  return result;
}

#endif /* 0 */