/* code-foreach.cc */ /* Generate code for foreach loops. Most of the complexity comes from setting up the pointers and increments for strength reduced address calculations and from optimizing bounds checks. See Geoff Pike's thesis for a description of loop opts. */ #include #include #include #include #include #include #include #include "AST.h" #include "ArrayAccessSet.h" #include "CodeContext.h" #include "CtGlobal.h" #include "CtLocal.h" #include "CtReference.h" #include "CtType.h" #include "CtJavaArray.h" #include "PrimitiveDecl.h" #include "code-bounds.h" #include "code-call.h" #include "code-defs.h" #include "code-foreach.h" #include "code-grid.h" #include "code-util.h" #include "code.h" #include "compiler.h" #include "ctBox.h" #include "decls.h" #include "delimit.h" #include "domain-decls.h" #include "errors.h" #include "interface.h" #include "lgMacro.h" #include "lower.h" #include "osstream.h" #include "utils.h" #include "optimize.h" #include "code-MIVE.h" #include "buildexpr.h" #include "PolyToCode.h" #include "stats.h" #ifndef DEBUG_UPDATEPOINTBEFORESTMTNODE #define DEBUG_UPDATEPOINTBEFORESTMTNODE 1 #endif #ifdef HAVE_BINARY_ADAPTORS typedef stack< string, list< string > > StringStack; #else typedef stack< list< string > > StringStack; #endif extern bool bounds_checking; extern bool opt_usi; /* Temporary counters, etc. 
(these are reset once per file output) */ static int labelNum = 0; static int numLoops = 0; static map *mapLoopToNum = NULL; static map > *mapLoopToBases = NULL; static map *mapLoopToQ = NULL; static map > > *mapLoopBaseDimToDiff = NULL; typedef pair PtrType; static map > *mapLoopBaseToPtrType = NULL; static map *mapBaseToUse = NULL; #define atZero(base) ((base) + "zero") /* Called from aux-code.cc once per output file */ void reset_foreach_counters() { numLoops = labelNum = 0; delete mapLoopToNum; mapLoopToNum = NULL; delete mapLoopToBases; mapLoopToBases = NULL; delete mapLoopToQ; mapLoopToQ = NULL; delete mapLoopBaseDimToDiff; mapLoopBaseDimToDiff = NULL; delete mapLoopBaseToPtrType; mapLoopBaseToPtrType = NULL; delete mapBaseToUse; mapBaseToUse = NULL; } static int loopNumber(ForEachStmtNode *l) { if (numLoops == 0) mapLoopToNum = new map; int &result = (*mapLoopToNum)[l]; if (result == 0) result = ++numLoops; return result; } static void saveSRbase(ForEachStmtNode *l, const string &base, const string use, PtrType &t) { const int ln = loopNumber(l); if (mapLoopToBases == NULL) mapLoopToBases = new map >; (*mapLoopToBases)[ln].insert(base); if (debug_sr) cout << "saveSRbase: " << ln << ' ' << base << endl; if (mapLoopBaseToPtrType == NULL) mapLoopBaseToPtrType = new map >; (*mapLoopBaseToPtrType)[ln][base] = t; if (mapBaseToUse == NULL) mapBaseToUse = new map; (*mapBaseToUse)[base] = use; } static void saveSRdiff(ForEachStmtNode *l, const string &base, int dim, string diff) { if (mapLoopBaseDimToDiff == NULL) mapLoopBaseDimToDiff = new map > >; (*mapLoopBaseDimToDiff)[loopNumber(l)][base][dim] = diff; if (debug_sr) cout << "saveSRdiff: " << loopNumber(l) << ' ' << base << ' ' << dim << ' ' << diff << endl; } static void saveSRq(ForEachStmtNode *l, const string &q) { if (mapLoopToQ == NULL) mapLoopToQ = new map; (*mapLoopToQ)[loopNumber(l)] = q; if (debug_sr) cout << "saveSRq: " << loopNumber(l) << ' ' << q << endl; } static void resetTempCounter(int *c) { *c = 
0; } #if 0 /* unused */ /* If f is a "normal" foreach or a "stripped" foreach then bump the counter. If a ForEachSetupNode is being emitted then do not. */ static void bumpTempCounter(ForEachStmtNode *f, int *c) { if (f->blockContext() == NULL || !f->shouldDeclare()) ++*c; } #endif const string pointField(const string p, int field, int arity) { if (arity == 1) { assert(field == 0); return p; } else return p + ".x" + int2string(field); } const string rectField(const string r, const string field, int arity) { return r + "." + MANGLE_TI_DOMAINS_RECTDOMAIN_FIELD_ACCESS(+, field, int2string(arity)); } string rectMin(const string &r, int dim, int arity) { return pointField(rectField(r, "p0", arity), dim, arity); } string rectMax(const string &r, int dim, int arity) { return "(" + pointField(rectField(r, "p1", arity), dim, arity) + " - 1)"; } string rectUpb(const string &r, int dim, int arity) { return pointField(rectField(r, "p1", arity), dim, arity); } string rectStride(const string &r, int dim, int arity) { return pointField(rectField(r, "loopStride", arity), dim, arity); } #if 0 /* unused */ static StringStack pointGuts(CodeContext& context, TreeNode *pt) { assert(!isPointNode(pt)); StringStack q; int arity = pt->type()->tiArity(); string index0 = pt->simpleVar(context); for (int i = 0; i < arity; i++) q.push(pointField(index0,i,arity)); return q; } #endif static const string indexCurrent(const string &q, int i) { return q + "_" + int2string(i); } static const string indexStart(const string &q, int i) { return q + "s_" + int2string(i); } static const string indexStride(const string &q, int i) { return q + "d_" + int2string(i); } static const string indexEnd(const string &q, int i) { return q + "e_" + int2string(i); } /* static StringStack currentPoint(const string &q, int arity) { StringStack s; for (int i = 0; i < arity; i++) s.push(indexCurrent(q, i)); return s; } */ /* i is a 1-based index. l is a foreach loop. 
Return a string for the current value of the iteration point in dimension i. */ string *& iterationPoint(TreeNode *l, int i) { static map_tree_int_to_pstring m; return m[l][i]; } static void foreachRectDomainStart(CodeContext &context, int arity, const string &q, TreeNode *dom, ForEachStmtNode *f, string &curr_rect) { // Case 1: Iterating over a known (possibly infinite) set of integers if (f->stride() != NULL) return; // Case 2: RectDomain specific startup if (isObjectNode(dom) || isOFAN(dom) || isTypeFAN(dom)) { curr_rect = dom->emitExpression(context); } else if (isGridMethod(dom, "domain")) { curr_rect = dom->child(0)->child(0)->emitExpression(context) + ".domain"; } else { string exp = dom->emitExpression(context); curr_rect = q + "_curr_rect"; context.declare(curr_rect, RectDomainNDecl[arity - 1]->cType()); context << curr_rect << " = " << exp << ";" << endCline; } } /* NB: this code may insert decls directly instead of using context.declare() because otherwise the decls might appear in the wrong place. A better way to do this would probably be to add a new method (e.g., declare_here) to CodeContext and use that. 
*/ static void foreachGeneralStart(CodeContext &context, int arity, const string &q, TreeNode *dom, bool _shouldDeclare) { // General domain specific startup ClassDecl &d = *DomainNDecl[ arity - 1 ]; const CtGlobal &d_inst = *new CtGlobal( d.cType() ); context.depend(d_inst); context.depend(*new CtGlobal(MultiRectADomainNDecl[ arity - 1 ]->cType())); const string d_var = q + "_d"; const string curr_rect_var = q + "_curr_rect"; const string arity_str = int2string(arity); if (_shouldDeclare) context << d_inst << ' ' << d_var << ';' << endCline; else context.declare( d_var, d_inst ); // new interface: // RectDomain [] local rdarray = d.getRectsJArray() // if (rdarray) { // non-empty // int rdcnt = rdarray.length; // RectDomain local *rdptr = &rdarray[0]; // while (rdcnt--) { // curr_rect = *(rdptr++); ClassDecl &rd = *RectDomainNDecl[ arity - 1 ]; const CtType &jarray_unboxed = *(new CtJavaArray(rd.cType())); const CtType &rdjarray = *(new CtLocal(jarray_unboxed)); context.depend(jarray_unboxed); context.depend(rdjarray); const string rdarr_var = q + "_rdarr"; const string rdcnt_var = q + "_rdcnt"; const string rdptr_var = q + "_rdptr"; if (_shouldDeclare) { context << rdjarray << ' ' << rdarr_var << ';' << endCline; } else { context.declare( rdarr_var, rdjarray ); } const string exp = dom->emitExpression(context); context << d_var << " = " << exp << ';' << endCline << rdarr_var << " = " << MANGLE_TI_DOMAINS_DOMAIN_DISPATCH(<<, arity_str, "getRectsJArray", "") << "(*(PT19tiMultiRectADomain" << arity_str << "7domains2ti*)&" << d_var << ");" << endCline; context << "if (" << rdarr_var << " != NULL) {" << endCline << " int " << rdcnt_var << ';' << endCline << " " << rd.cType() << " *" << rdptr_var << ';' << endCline << " JAVA_ARRAY_LENGTH_LOCAL(" << rdcnt_var << ", " << rdarr_var << ");" << endCline << " JAVA_ARRAY_ADDR_LOCAL(" << rdptr_var << ", " << rdarr_var << ", 0, \"foreach over general domain\");" << endCline << " while (" << rdcnt_var << "--) {" << endCline << " 
" << makeRectDomainType(arity)->cType() << " " << curr_rect_var << " = " << "*(" << rdptr_var << "++);" << endCline; #if 0 // old, linked-list based interface: // iter = d.getRectangles().newIterator() // do curr_rect = iter[i] #define MANGLE_TI_DOMAINS_RDL_NEWITERATOR(sep, arity) \ "m11newIteratormT17tiRectDomainList" sep arity sep "7domains2ti" #define MANGLE_TI_DOMAINS_RDL_ISEND(sep, arity) \ "m5isEndmT25tiRectDomainListIterator" sep arity sep "7domains2ti" #define MANGLE_TI_DOMAINS_RDL_ADVANCE(sep, arity) \ "m7advancemT25tiRectDomainListIterator" sep arity sep "7domains2ti" #define MANGLE_TI_DOMAINS_RDL_DEREFERENCE(sep, arity) \ "m11dereferencemT25tiRectDomainListIterator" sep arity sep "7domains2ti" ClassDecl &rdl = *RectDomainListNDecl[ arity - 1 ]; ClassDecl &iter = *RectDomainListIteratorNDecl[ arity - 1 ]; // DOB: commented-out code implements dynamic dispatch for Domain, which we may decide to put back some day #if 0 const CtType &d_desc = d.cDescriptorType(); const CtType &rdl_desc = rdl.cDescriptorType(); const CtType &iter_desc = iter.cDescriptorType(); const CtGlobal &gl_d_desc = *new CtGlobal( *new CtLocal( d_desc ) ); const CtGlobal &gl_rdl_desc = *new CtGlobal( *new CtLocal( rdl_desc ) ); const CtGlobal &gl_iter_desc = *new CtGlobal( *new CtLocal( iter_desc ) ); #endif const CtGlobal &rdl_inst = *new CtGlobal( rdl.cType() ); const CtGlobal &iter_inst = *new CtGlobal( iter.cType() ); const string mdl_var = q + "_mdl"; const string d_iter_var = q + "_d_iter"; const string iter_cont_var = q + "_iter_cont"; context.depend(rdl_inst); context.depend(iter_inst); #if 0 const string d_desc_var = q + "_d_desc"; const string mdl_desc_var = q + "_mdl_desc"; const string d_iter_desc_var = q + "_d_iter_desc"; context.depend( d_desc ); context.depend( rdl_desc ); context.depend( iter_desc ); context.depend( gl_d_desc ); context.depend( gl_rdl_desc ); context.depend( gl_iter_desc ); #endif if (_shouldDeclare) { context << rdl_inst << ' ' << mdl_var << ';' << 
endCline; context << iter_inst << ' ' << d_iter_var << ';' << endCline; context << PrimitiveDecl::BoolDecl.cType() << ' ' << iter_cont_var << ';' << endCline; } else { context.declare( mdl_var, rdl_inst ); context.declare( d_iter_var, iter_inst ); context.declare( iter_cont_var, PrimitiveDecl::BoolDecl.cType() ); } #if 0 context << d_desc << " *" << d_desc_var << ';' << endCline << rdl_desc << " *" << mdl_desc_var << ';' << endCline << iter_desc << " *" << d_iter_desc_var << ';' << endCline; #endif const string exp = dom->emitExpression(context); context << d_var << " = " << exp << ';' << endCline //<< "CLASS_INFO_GLOBAL(" << d_desc_var << ", " //<< d_desc << ", " << d_var << ");" << endCline //<< mdl_var << " = " << d_desc_var << "->" << MANGLE_TI_DOMAINS_DOMAIN_GETRECTANGLES << d_var << ");" << endCline << mdl_var << " = " << MANGLE_TI_DOMAINS_DOMAIN_DISPATCH(<<, arity_str, "getRectangles", "") << "(*(PT19tiMultiRectADomain" << arity_str << "7domains2ti*)&" << d_var << ");" << endCline //<< "CLASS_INFO_GLOBAL(" << mdl_desc_var << ", " //<< rdl_desc << ", " << mdl_var << ");" << endCline //<< d_iter_var << " = " << mdl_desc_var << "->" << MANGLE_TI_DOMAINS_RDL_NEWITERATOR << mdl_var << ");" << endCline << d_iter_var << " = " << MANGLE_TI_DOMAINS_RDL_NEWITERATOR(<<,arity_str) << "(" << mdl_var << ");" << endCline //<< "CLASS_INFO_GLOBAL(" << d_iter_desc_var << ", " //<< iter_desc << ", " << d_iter_var << ");" << endCline << "do {" << endCline << " " << iter_cont_var << " = !(" << MANGLE_TI_DOMAINS_RDL_ISEND(<<,arity_str) << "(" << d_iter_var << "));" << endCline << " if (" << iter_cont_var << ") {" << endCline << " " << makeRectDomainType(arity)->cType() << " " << curr_rect_var << ";" << endCline << " " << curr_rect_var << " = " << MANGLE_TI_DOMAINS_RDL_DEREFERENCE(<<,arity_str) << "(" << d_iter_var << ");" << endCline; #endif } static void updateIterationPoint(CodeContext &os, const string &q, const string &iterationVar, int arity) { os << 
MANGLE_TI_DOMAINS_POINT_CONSTRUCT(<<, int2string(arity)) << MANGLE_STACK_VAR(<<, iterationVar); for (int i = 0; i < arity; i++) os << ", " << indexCurrent(q, i); os << ");" << endCline; } // Set the point-valued variable to the values in indices. static void updatePoint(CodeContext &os, const string &v, StringStack indices) { int arity = indices.size(); os << MANGLE_TI_DOMAINS_POINT_CONSTRUCT(<<, int2string(arity)) << v; for (int i = 0; i < arity; i++, indices.pop()) os << ", " << indices.top(); os << ");" << endCline; } static string quickconv(StringStack indices, const string& arr, CodeContext &context, map_int_to_treeSet &provablyUnitStride, const TreeNode *array) { string s; for (int i = indices.size(); --i >= 0; indices.pop()) { string d = difference(indices.top(), arr + ".base[" + int2string(i) + "]", context); string q = contains(provablyUnitStride[i], array) ? d : quotient_with_runtime_unitdivisor_opt(d, arr + ".stride[" + int2string(i) + "]", context); /* if (DEBUG_BC) if (contains(provablyUnitStride[i], array)) cout << "Saved division by 1 (" << arr << ".stride[" << i << "])" << endl; */ string p = product(q, arr + ".sideFactors[" + int2string(i) + "]", context); s = (s == "" ? p : sum(s, p, context)); } return s; } // Compute how far in memory we move (in units of sizeof(array elt)) // when the index of an array index changes by d times multiplier in // the i'th dimension. (i counts from 0.) If bounds checking is on // AND requireCheck is true then check that d times multiplier is // a multiple of the stride of the array in dimension i. 
static string convertDiffToOffset(const Poly *d, int i, ForEachStmtNode *WRTloop, const string &arr, const string &multiplier, const string &requireCheck, CodeContext &context, map_int_to_treeSet &provablyUnitStride, const TreeNode *ad, bool partialDomain) { string result; if (d->constantp()) result = int2string(d->asInt()); else result = PolyToCode(d, WRTloop, context); if (result != "0") { TypeNode *type = ad->dtype(); if (partialDomain) context << "if (!(" << callGridMethod(*type, "isnull", "(" + arr + ")") << ")) {" << endCline; current_context = &context; result = product(result, multiplier); if (contains(provablyUnitStride[i], ad)) { /* if (DEBUG_BC) cout << "Saved division by 1" << endl; */ } else { if (!partialDomain && requireCheck != "0" && bounds_checking) { context << "#if BOUNDS_CHECKING" << endCline << "if (" << requireCheck << ")" << endCline << " "; int arity = type->tiArity(); const string stride = rectStride(arr + ".domain", i, arity); const string condition = "((" + result + ") % (" + stride + ")) == 0"; generateBoundsAssert(context, condition, arr, type, i, WRTloop, result, " must be a multiple of ", stride); if (DEBUG_BC) cout << "assert(" << condition << ")" << endl; context << "#endif" << endCline; } result = quotient_with_runtime_unitdivisor_opt(result, arr + ".stride[" + int2string(i) + "]"); } result = product(result, arr + ".sideFactors[" + int2string(i) + "]"); if (partialDomain) context << "}" << endCline; } return result; } // i and j are 0-based indices static string computeDelta(const MIVE *m, int i, ForEachStmtNode *WRTloop, int j, const string &arr, const string &domainStride, const string &requireCheck, CodeContext &context, map_int_to_treeSet &provablyUnitStride, const TreeNode *ad, bool partialDomain) { const Poly *d = deriv(m, i, WRTloop, j); assert(d != NULL); string result = convertDiffToOffset(d, i, WRTloop, arr, domainStride, requireCheck, context, provablyUnitStride, ad, partialDomain); if (debug_sr) cout << "compute 
change in " << stringify(m->p[i]) << " as " << pseudocode(WRTloop->vars()->child(0)->simpName()) << "[" << j+1 << "]" " changes: " << result << " (deriv is " << stringify(d) << ")" << endl; return result; } /* Return C code for evaluating ARRAY[POINT] for a titanium array and a point of the appropriate arity. Unless the C code is compiled with -DBOUNDS_CHECKING=0 then, at run time, the program will verify that the array's domain contains the point. If the domain doesn't contain the point then the program halts with an error message that includes the WHERE string. */ static string safePointer(const string &array, const string &point, TypeNode &t, const string &where) { return callGridMethod(t, "_PTR", "(\"" + where + "\", " + array + ", " + point + ")"); } #define DECLARE(context, v, type) \ do { if (_shouldDeclare) (context).declare((v), (type)); } while (0) static void setupConstantPointer(TreeNode *arr, TreeNode *ind, ArrayAccessSet *a, ForEachStmtNode *WRTloop, CodeContext &context, bool partialDomain, bool _shouldDeclare) { TreeNode *acc = a->use(arr, ind); string ptr = LIPTempname(WRTloop, arr, ind); TypeNode *t = arr->dtype(); int arrayArity = t->tiArity(); bool isGlobal = !(t->modifiers() & Common::Local); const CtType &eltType = t->elementType()->cType(); DECLARE(context, ptr, ctBox(eltType, !isGlobal)); const string array = acc->array()->emitExpression(context); const string p = string("loop at ") + WRTloop->position().asString(); if (partialDomain) context << "if (!(" << callGridMethod(*t, "isnull", "(" + array + ")") << ")) {" << endCline; /* If possible, use PolyToCode to calculate the index. That way, variables whose value is known need not be evaluated. 
*/ { CodeContext subcontext(context); string index; MIVEcontext *e = MIVEcontextOfLoop(WRTloop); const MIVE *m = acc->index()->getMIVE(WRTloop); if (m != NULL && m->p != NULL) { StringStack s; for (int i = arrayArity-1; i >= 0; i--) s.push(PolyToCode(m->p[i], e, subcontext)); index = "index_temporary"; subcontext.declare(index, makePointType(arrayArity)->cType()); updatePoint(subcontext, index, s); } if (index == "") index = acc->index()->emitExpression(subcontext); subcontext << ptr << " = " << safePointer(array, index, *t, p) << ';' << endCline; } if (partialDomain) context << "}" << endCline; } #define START(ns) (string("s") + (ns)) #define DIFF(ns, dim) (string("d") + (ns) + '_' + int2string(dim)) #define nonTrivialSize(dim) \ ('(' + indexCurrent(q, dim) + " + " + \ indexStride(q, dim) + " < " + \ indexEnd(q, dim) + ')') void setupOSR(TreeNode *arr, const MIVE *base, const MIVE *dest, ArrayAccessSet *a, ForEachStmtNode *WRTloop, CodeContext &context, map_int_to_treeSet &provablyUnitStride, TreeNode *ad, bool partialDomain, bool _shouldDeclare) { if (debug_sr) cout << "setupOSR(loop at " << WRTloop->position().asString() << ")" << endl; const string array = arr->emitExpression(context); if (partialDomain) context << "if (!(" << callGridMethod(*(arr->type()), "isnull", "(" + array + ")") << ")) {" << endCline; const CtType &eltType = arr->type()->elementType()->cType(); string result = "0"; // m, the thing we're computing, is dest - base. 
const MIVE *m = dest->add(base->negate()); // convert m to distance in memory for (int i = 0; i < m->arity(); i++) { string s = convertDiffToOffset(m->p[i], i, WRTloop, array, "1", "1", context, provablyUnitStride, ad, partialDomain); result = sum(result, s, context); } string t = a->OSRTempname(WRTloop, arr, dest); DECLARE(context, t, theIntType->cType()); if (debug_sr && _shouldDeclare) cout << "setupOSR: declare " << t << endl; result = product(result, string("sizeof(") + eltType + string(")"), context); context << t << " = " << result << ';' << endCline; if (partialDomain) context << "}" << endCline; } /* CONTEXT is the CodeContext for outputting code. M contains the pairs that we care about for strength reduction. ARRAYACCESSES are the optimized access in the loop body WRT to this loop. L is this loop (ForEachStmtNode). Q is a string for uniquely identifying a family of C variables. DIM is the dimension to work on---SRInit() is called once per dimension. ARITY is the arity of the rectangle. INCR is modified (by appending) to contain code needed to increment the pointers that are used by strength reduction. PROVABLYUNITSTRIDE may contain an array iff that array provably has unit stride in dimension i. */ static void SRInit(CodeContext &context, map_tree_to_cMIVElist &m, ArrayAccessSet *arrayAccesses, ForEachStmtNode *l, const string &q, int dim, int arity, string &incr, map_int_to_treeSet &provablyUnitStride, bool partialDomain, bool _shouldDeclare, bool saveSRinfo) { if (arrayAccesses == NULL) return; if (debug_sr) cout << "SRInit(loop at " << l->position().asString() << ")" << endl; bool noneHandled = true; /* Part 1: Compute initial value for the SR pointers. 
*/ for (map_tree_to_cMIVElist::const_iterator a = m.begin(); a != m.end(); a++) { TreeNode *acc = arrayAccesses->use((*a).first); TreeNode *arr = acc->array(); const string array = arr->emitExpression(context); foreach (z, llist, *((*a).second)) { const MIVE *p = *z; TypeNode *t = arr->type(); bool isGlobal = !(t->modifiers() & Common::Local); const CtType &eltType = t->elementType()->cType(); const int arrayArity = t->tiArity(); const string arityStr = int2string(arrayArity); const string R = arrayMethodPrefix(t) + "domain(" + eltType + ", " + arityStr + ")(" + array + ")"; const string ns = SRTempname(l, (llist * &) (*a).second, acc, p); const string prevbase = SRvar_at_depth(ns, dim - 1); const string base = SRvar_at_depth(ns, dim); /* cout << "Iteration Arity = " << arity << ";" " Array Arity = " << arrayArity << endl; */ assert(p->arity() == arrayArity); noneHandled = false; // Make a note, for later, of how to increment the pointer when // moving in this direction. incr += lgMacro("INDEX", *t) + "(" + base + ", " + base + ", " + DIFF(ns, dim) + "); "; if (dim == 0) { DECLARE(context, base, ctBox(eltType, !isGlobal)); if (saveSRinfo) { PtrType p(t->elementType(), !isGlobal); saveSRbase(l, base, SRvar(ns, arity), p); for (int d = 0; d < arity; d++) saveSRdiff(l, base, d, DIFF(ns, d)); } StringStack s; for (int i = 0; i < arrayArity; i++) { s.push(PolyToCode(p->p[i], l, context)); } if (partialDomain) context << "if (!(" << callGridMethod(*t, "isnull", "(" + array + ")") << ")) {" << endCline; string qc = quickconv(s, array, context, provablyUnitStride, arrayVarDecl(arr)); context << lgMacro("INDEX", *t) << '(' << base << ", " << array << ".A, " << qc << ");" << endCline; if (partialDomain) context << "}" << endCline; } else { DECLARE(context, base, ctBox(eltType, !isGlobal)); context << base << " = " << prevbase << ";" << endCline; } } } /* Part 2: Constant pointers */ if (dim == 0) { context << "/* loop invar pointers */" << endCline; for 
(ArrayAccessSet::lip_iterator i = arrayAccesses->lip_begin(); !i.isDone(); i.next()) setupConstantPointer((TreeNode *) i.array(), (TreeNode *) i.index(), arrayAccesses, l, context, partialDomain, _shouldDeclare); context << "/* end of loop invar pointer setup */" << endCline; } if (noneHandled) return; /* Part 3: OSR */ if (opt_osr) { if (dim == 0) { context << "/* OSR calculations */" << endCline; for (ArrayAccessSet::osr_iterator i = arrayAccesses->begin(); !i.isDone(); i.next()) setupOSR(arrayAccesses->use(i.array())->array(), i.base(), i.dest(), arrayAccesses, l, context, provablyUnitStride, i.array(), partialDomain, _shouldDeclare); context << "/* end of OSR calculations */" << endCline; } } else context << "/* OSR disabled */" << endCline; /* Part 4: incrementing the pointers */ // Compute the (pointer) increments necessary for each dimension. // Output that calculation inside an "if (nontrivialsize()) { ... }". // If the size of the domain is trivial then we don't need the delta // per iteration because there will only be zero or one iterations. if (dim == 0) { for (int d = 0; d < arity; d++) { for (map_tree_to_cMIVElist::const_iterator a = m.begin(); a != m.end(); a++) { TreeNode *acc = arrayAccesses->use((*a).first); TreeNode *arr = acc->array(); const string array = arr->emitExpression(context); foreach (z, llist, *((*a).second)) { const MIVE *p = *z; const string ns = SRTempname(l, (llist * &) (*a).second, acc, p); DECLARE(context, DIFF(ns, d), theIntType->cType()); if (use_nonTrivialSizeCheck) context << "if " << nonTrivialSize(d) << " {"; for (int i = 0; i < p->arity(); i++) { string NTS = use_nonTrivialSizeCheck ? string("1") : nonTrivialSize(d); string domainStride = indexStride(q, d); string s = computeDelta(p, i, l, d, array, domainStride, NTS, context, provablyUnitStride, arrayVarDecl(arr), partialDomain); context << DIFF(ns, d) << (i == 0 ? 
" = " : " += ") << s << ';' << endCline; } if (use_nonTrivialSizeCheck) context << "}" << endCline; } } } } } #undef DIFF #undef nonTrivialSize static int countUSI(map_int_to_treeSet &provablyUnitStride) { int count = 0; for (map_int_to_treeSet::const_iterator i = provablyUnitStride.begin(); i != provablyUnitStride.end(); ++i) if (i->second) count += i->second->size(); return count; } #define middle_label(i) "middle" << labelNum << "_" << (i) #define top_label(i) "top" << labelNum << "_" << (i) /* ARITY is the arity of the rectangle. Q is a string for uniquely identifying a family of C variables. DOM is the domain of the iteration (not necessarily just this rectangle). VAR is the iteration variable. STMT is the body of the loop. NEEDPOINT is whether to update the point-valued iteration var explicitly; it may be modified when codeGen() is called on STMT. CURR_RECT is a string for the C variable over which we're iterating now. CANBEEMPTY is whether curr_rect can have 0 points in it. M contains the pairs that we care about for strength reduction. ARRAYACCESSES are the optimized access in the loop body WRT to this loop. L is this loop (ForEachStmtNode). PARTIALDOMAIN is whether this is a partial domain loop. */ static void foreachRectIterate(CodeContext &context, int arity, string q, TreeNode *dom, TreeNode *var, TreeNode *stmt, bool &needPoint, string &curr_rect, bool canBeEmpty, map_tree_to_cMIVElist &m, ArrayAccessSet *arrayAccesses, ForEachStmtNode *l, bool partialDomain, bool _shouldDeclare, bool saveSRinfo) { /* Approximate idea: iterate over RectDomain curr_rect s = getLoopingStride; p0 = getLowerBound; p1 = getUpperBound; for (var0, p0.get(0), p1.get(0), s.get(0)) ... for (var-n, p0.get(n), p1.get(n), s.get(n)) do stuff ... end for var-n end for var0 */ int i; /* Ensure unique labels for gotos (if any). */ labelNum++; // Preliminaries // Set up variables for the current value, start, stride, and end of // each dimension of the iteration. 
Exception: for the outermost // loop of the iteration we don't need to keep the start around. // Warning: This is relies on knowledge about the domain library // implementation because it pulls the rectdomain variables directly // from the struct. for (i = 0; i < arity; i++) { DECLARE(context, indexCurrent(q, i), theIntType->cType()); iterationPoint(l, i + 1) = new string(indexCurrent(q, i)); context << indexCurrent(q, i) << " = " << pointField(rectField(curr_rect, "p0", arity), i, arity) << ';' << endCline; if (i > 0) { DECLARE(context, indexStart(q, i), theIntType->cType()); context << indexStart(q, i) << " = " << indexCurrent(q, i) << ";" << endCline; } DECLARE(context, indexStride(q, i), theIntType->cType()); context << indexStride(q, i) << " = " << pointField(rectField(curr_rect, "loopStride", arity), i, arity) << ';' << endCline; DECLARE(context, indexEnd(q, i), theIntType->cType()); context << indexEnd(q, i) << " = " << pointField(rectField(curr_rect, "p1", arity), i, arity) << ';' << endCline; } if (canBeEmpty) { context << "if ("; // Warning: This is relies on knowledge about the domain library // implementation. #if 0 // OLD RD REP // If the start and end values are all 0, this is a null domain. for (i = 0; i < arity; i++) { if (i != 0) context << " || "; context << "((" << indexCurrent(q, i) << " != 0) || " "(" << indexEnd(q, i) << " != 0))"; } #else // NEW RD REP // Empty RD's always represented with a negative stride point context << indexStride(q,0) << " > 0"; #endif context << ") {" << endCline; } string *incr = new string [arity]; const string &iterationVar = *(var->ident()); /* We may be able to determine that certain arrays have unit stride in certain dimensions. */ map_int_to_treeSet provablyUnitStride; if (!partialDomain && arrayAccesses != NULL) { if (opt_usi) arrayAccesses->seekUnitStride(provablyUnitStride, l); if (bounds_checking) { /* At run time, check that the strides that must be 1 are 1. */ /* (Unless bounds checking is off.) 
*/ generateUnitStrideCheck(context, l, arrayAccesses, provablyUnitStride); generateBoundsChecks(context, l, new StringRect(curr_rect, arity), arrayAccesses, provablyUnitStride); } } // top of the loop for (i = 0; i < arity; i++) { SRInit(context, m, arrayAccesses, l, q, i, arity, incr[i], provablyUnitStride, partialDomain, _shouldDeclare, saveSRinfo); if (i == 0) context << "/* Top of loop */" << endCline; // if i is zero we skip this, because the FOOs_BAR variable doesn't // exist---the correct value is already in FOO_BAR. if (i > 0) context << indexCurrent(q, i) << " = " << indexStart(q, i) << ";" << endCline; if (use_goto) { context << "goto " << middle_label(i) << ";" << endCline; context << top_label(i) << ": ;" << endCline; if (!incr[i].empty()) context << incr[i] << endCline; context << middle_label(i) << ": ;" << endCline; } else { context << "do {" << endCline; } } // body of the loop { CodeContext subcontext(context); stmt->codeGen(subcontext); /* If we need the iteration point, insert code in front of the subcontext that sets it at the start of each iteration. */ if (needPoint) updateIterationPoint(context, q, iterationVar, arity); if (stat_foreach) loopStats(stmt, dom->type(), partialDomain, needPoint, countUSI(provablyUnitStride), l->numLifted(), cout); } // bottom of the loop for (i = arity; --i >= 0; ) { if (use_goto) context << "if ((" << indexCurrent(q, i) << " += " << indexStride(q, i) << ") < " << indexEnd(q, i) << ") " "goto " << top_label(i) << ";" << endCline; else { context << "if ((" << indexCurrent(q, i) << " += " << indexStride(q, i) << ") >= " << indexEnd(q, i) << ") break;" << endCline; if (!incr[i].empty()) context << incr[i] << endCline; context << "} while (1);" << endCline; } } if (canBeEmpty) context << "}" << endCline; delete[] incr; } #undef middle_label #undef top_label /* ARITY is the arity of the rectangle. Q is a string for uniquely identifying a family of C variables. LO, HI, and STRIDE specify the range to iterate over. 
VAR is the iteration variable. STMT is the body of the loop. NEEDPOINT is whether to update the point-valued iteration var explicitly; it may be modified when codeGen() is called on STMT. M contains the pairs that we care about for strength reduction. ARRAYACCESSES are the optimized access in the loop body WRT to this loop. L is this loop (ForEachStmtNode). PARTIALDOMAIN is whether this is a partial domain loop. */ static void foreachRangeIterate(CodeContext &context, int arity, string q, TreeNode *lo, TreeNode *hi, TreeNode *stride, TreeNode *var, TreeNode *stmt, bool &needPoint, map_tree_to_cMIVElist &m, ArrayAccessSet *arrayAccesses, ForEachStmtNode *l, bool partialDomain, bool _shouldDeclare, bool saveSRinfo) { assert(arity == 1); const string &iterationVar = *(var->ident()); string los, his, strides; // strings for lo, hi, and stride // Preliminaries // Set up variables for the current value, stride, and end of // the iteration. It MUST be a 1D iteration. DECLARE(context, indexCurrent(q, 0), theIntType->cType()); iterationPoint(l, 1) = new string(indexCurrent(q, 0)); if (lo != NULL && !lo->absent()) { los = lo->emitExpression(context); context << "/* range start */ " << indexCurrent(q, 0) << " = " << los << ';' << endCline; } DECLARE(context, indexStride(q, 0), theIntType->cType()); if (stride != NULL && !stride->absent()) { strides = stride->emitExpression(context); context << "/* range stride */ " << indexStride(q, 0) << " = " << strides << ';' << endCline; } DECLARE(context, indexEnd(q, 0), theIntType->cType()); if (hi != NULL && !hi->absent()) { his = hi->emitExpression(context); context << "/* range end */ " << indexEnd(q, 0) << " = " << his << ';' << endCline; } if (DEBUG_OPT) cout << "foreachRangeIterate(" << l->position().asString() << "):" " lo=" << los << " hi=" << his << " stride=" << strides << " partialDomain=" << partialDomain << endl; /* We may be able to determine that certain arrays have unit stride in certain dimensions. 
*/ map_int_to_treeSet provablyUnitStride; if (!partialDomain && arrayAccesses != NULL) { if (opt_usi) arrayAccesses->seekUnitStride(provablyUnitStride, l); if (bounds_checking) { /* At run time, check that the strides that must be 1 are 1. */ /* (Unless bounds checking is off.) */ generateUnitStrideCheck(context, l, arrayAccesses, provablyUnitStride); generateBoundsChecks(context, l, new StringRect(los, his), arrayAccesses, provablyUnitStride); } } // top of the loop string incr; SRInit(context, m, arrayAccesses, l, q, 0, arity, incr, provablyUnitStride, partialDomain, _shouldDeclare, saveSRinfo); context << "/* Top of loop (1D foreach on range) */" << endCline; assert(his == ""); if (strides == "") context << "for (;;) "; else context << "for ( ; ; " << indexCurrent(q, 0) << " += " << strides << ") "; context << '{' << endCline; // body of the loop { CodeContext subcontext(context); stmt->codeGen(subcontext); /* If we need the iteration point, insert code in front of the subcontext that sets it at the start of each iteration. */ if (needPoint) updateIterationPoint(context, q, iterationVar, arity); if (stat_foreach) loopStats(stmt, NULL, partialDomain, needPoint, countUSI(provablyUnitStride), l->numLifted(), cout); } // bottom of the loop context << incr << endCline << '}' << endCline << "/* Bottom of loop (1D foreach on range) */" << endCline; } #undef DECLARE static void foreachGeneralEnd(CodeContext &os, int arity, string q) { // General domain specific tail os << " } }" << endCline; #if 0 // advance curr_rect // end while curr_rect os << " " << MANGLE_TI_DOMAINS_RDL_ADVANCE(<<, int2string(arity)) "(" << q << "_d_iter);" << endCline; os << " }" << endCline; os << "} while (" << q << "_iter_cont);" << endCline; #endif } //////////////////////////////////////////////////////////////////////////// // Helpers for void UpdatePointBeforeStmtNode::emitStatement(). 
//////////////////////////////////////////////////////////////////////////// extern string SRAstr(TreeNode *t); static bool isMatchingSRA(TreeNode *t, const string &s, bool e) { return (isSRArrayAccessNode(t) || isOSRArrayAccessNode(t)) && (!e || t->appearsOnEveryIter()) && SRAstr(t) == s; } /* If one of t's children is an SR/OSR access whose codeString equals s then return true and set result to the index of that child. (If e is true then restrict the search to expressions that appear on every iteration.) Otherwise return false and do not modify result. */ static bool containsTopLevelSRA(TreeNode *t, const string &s, bool e, int &result) { int arity = t->arity(); for (int i = 0; i < arity; i++) if (isMatchingSRA(t->child(i), s, e)) { result = i; return true; } return false; } /* Find a use of SRA s in t. Insert in t a statement to copy that use into a new temporary. Put a decl for that new temporary in decl, and an assignment to reg (copying it) in final_assn. decl should be NULL upon entry and will become non-NULL iff this function is successful. 
   Returns the tree to use in place of t: a new BlockNode holding t
   followed by the statement produced by MakeTemporary() when the use is
   found directly in t, otherwise t itself (whose children may have been
   replaced in place by the recursive search). */
static TreeNode *saveSRAinReg(TreeNode *t, const string &s, TreeNode *reg,
                              TreeNode *&decl, TreeNode *&final_assn)
{
  /* i is only written by containsTopLevelSRA() when it reports a match. */
  int i, arity = t->arity();
  if (t->isStatementNode() && containsTopLevelSRA(t, s, true, i) ||
      isExpressionStmtNode(t) &&
      containsTopLevelSRA(t->child(0), s, true, i)) {
    if (DEBUG_OPT)
      cout << "saveSRAinReg() working on child " << i << " of "
           << pseudocode(t) << endl;
    TreeNode *val; // the value to copy into reg
    if (isExpressionStmtNode(t)) {
      if (isAssignNode(t->child(0)))
        /* If SRA is RHS of an assignment then copy the LHS;
           if SRA is LHS of an assignment then copy the RHS */
        val = t->child(0)->child(1 - i);
      else
        /* Non-assignment: redo the load (rare) */
        val = t->child(0)->child(i);
    } else
      /* Non-assignment: redo the load (rare) */
      val = t->child(i);
    TreeNode *assn;
    ObjectNode *w = MakeTemporary(CloneTree(val), decl, assn);
    final_assn =
      new ExpressionStmtNode(new AssignNode(reg, w, val->position()),
                             val->position());
    return new BlockNode(cons(t, cons(assn)), NULL, t->position());
  } else
    /* Not found at this level: try the children from last to first and
       stop as soon as one of them succeeds (decl becomes non-NULL). */
    for (i = arity; decl == NULL && --i >= 0; )
      t->child(i, saveSRAinReg(t->child(i), s, reg, decl, final_assn));
  return t;
}

/* Replace every SR/OSR access in t whose SRAstr() equals s by a clone of
   reg and return the resulting tree.  When rvalOnly is true, a node that
   is the LHS of an assignment is not replaced itself, but its children
   are still visited. */
static TreeNode *replaceSRA(TreeNode *t, const string &s, TreeNode *reg,
                            bool rvalOnly = false)
{
  if (!(rvalOnly && isLHSofAssignNode(t)) && isMatchingSRA(t, s, false))
    return CloneTree(reg);
  else
    for (int i = t->arity(); --i >= 0; )
      t->child(i, replaceSRA(t->child(i), s, reg, rvalOnly));
  return t;
}

/* Same as replaceSRA() with rvalOnly set. */
static TreeNode *replaceNonLvalSRA(TreeNode *t, const string &s,
                                   TreeNode *reg)
{
  return replaceSRA(t, s, reg, true);
}

////////////////////////////////////////////////////////////////////////////

/* Patch the wrapped statement according to urs(), sirs() and rfrs(),
   generate code for it, and emit code that updates the iteration point
   and the strength-reduction pointers recorded for the enclosing loop.

   NOTE(review): static_cast, llist, set and map appear below without
   template argument lists; they look like they were lost from this copy
   of the file and must be restored before this compiles -- confirm
   against the repository. */
void UpdatePointBeforeStmtNode::emitStatement(CodeContext &context)
{
  ForEachStmtNode *f = static_cast(WRTloop());

  if (DEBUG_UPDATEPOINTBEFORESTMTNODE) {
    cout << endl << "UPBS Before:" << endl;
    stmt()->pseudoprint(cout);
  }

  /* Apply patches implied by urs, rfrs, and sirs */

  /* urs: each pair is (reg, acc); every occurrence of acc in the
     statement is replaced by reg. */
  foreach (ur, llist, *urs()) {
    TreeNode *reg = (*ur).first, *acc = (*ur).second;
    /* here: this to be removed */
    if (reg == NULL && acc == NULL) continue;
    if (DEBUG_UPDATEPOINTBEFORESTMTNODE)
      cout << "ur patch: use " << pseudocode(reg) << " for "
           << SRAstr(acc) << endl;
    stmt(replaceSRA(stmt(), SRAstr(acc), reg));
  }

  /* sirs: each pair is (reg, acc); the value of acc is saved in a new
     temporary and finally copied to reg.  saveSRAinReg() must succeed
     for every pair (asserted). */
  {
    llist *sir_decls = NULL, *sir_finals = NULL;
    foreach (sir, llist, *sirs()) {
      TreeNode *reg = (*sir).first, *acc = (*sir).second,
        *decl = NULL, *final_assn;
      stmt(saveSRAinReg(stmt(), SRAstr(acc), reg, decl, final_assn));
      assert(decl != NULL);
      push(sir_decls, decl);
      push(sir_finals, final_assn);
      if (DEBUG_UPDATEPOINTBEFORESTMTNODE)
        cout << "sir patch: " << pseudocode(reg) << " saves "
             << SRAstr(acc) << endl;
    }
    /* Presumably yields: the collected decls, then the patched
       statement, then the final assignments -- confirm the semantics
       of extend() and cons(). */
    if (sir_decls != NULL)
      stmt(new BlockNode(extend(sir_decls, cons(stmt(), sir_finals)), NULL));
  }

  /* rfrs: each pair is (reg, acc); reads of acc are replaced by reg, but
     an acc that is the LHS of an assignment is left alone. */
  foreach (rfr, llist, *rfrs()) {
    TreeNode *reg = (*rfr).first, *acc = (*rfr).second;
    if (DEBUG_UPDATEPOINTBEFORESTMTNODE)
      cout << "rfr patch: read " << pseudocode(reg) << " instead of "
           << SRAstr(acc) << endl;
    stmt(replaceNonLvalSRA(stmt(), SRAstr(acc), reg));
  }

  if (DEBUG_UPDATEPOINTBEFORESTMTNODE) {
    cout << endl << "UPBS After:" << endl;
    stmt()->pseudoprint(cout);
  }

  /* The statement is generated into a subcontext while the updates below
     are written to context itself.  NOTE(review): presumably the
     subcontext's output lands after them, which is what puts the update
     "before" the statement -- confirm against CodeContext. */
  CodeContext subcontext(context);
  stmt()->codeGen(subcontext);  // side effect: sets f->needPoint()
  const bool needPoint = f->needPoint();

  /* When the loop does not need the point, the point update is still
     emitted, but wrapped in #if 0 ... #endif in the generated code. */
  if (!needPoint) context << endCline << "#if 0" << endCline;
  int ln = loopNumber(f), arity = f->tiArity();
  TreeNode *var = f->vars()->child(0)->simpName();
  const string &iterationVar = *(var->ident());
  const string mangledIV = MANGLE_STACK_VAR(+, iterationVar);
  /* One assignment (or increment, when the values are deltas) per
     dimension that has a point value. */
  for (int i = 0; i < arity; i++)
    if (pointValues()[i] != NULL)
      context << pointField(mangledIV, i, arity)
              << (valuesAreDeltas() ? " += " : " = ")
              << *(pointValues()[i]) << ';' << endCline;
  if (!needPoint) context << endCline << "#endif" << endCline;

  /* Update SR pointers */
  /* For every base saved for this loop, emit either
       ptr += (v0 * d0 + v1 * d1 + ...);
     or, when the values are absolute,
       ptr = <base>zero + (v0 * d0 + v1 * d1 + ...);
     where vi are the non-NULL point values and di the saved
     per-dimension differences. */
  if (mapLoopToBases != NULL) {
    set &s = (*mapLoopToBases)[ln];
    for (set::iterator base = s.begin(); base != s.end(); base++) {
      string at0 = atZero(*base);
      string ptr = (*mapBaseToUse)[*base]; // the pointer to be updated
      map &m = (*mapLoopBaseDimToDiff)[ln][*base];
      context << ptr
              << (valuesAreDeltas() ? " +=" : (" = " + at0 + " + "))
              << " (";
      bool needPlus = false;  // true once the first term has been written
      for (int i = 0; i < arity; i++)
        if (pointValues()[i] != NULL) {
          if (needPlus) context << " + ";
          else needPlus = true;
          context << *(pointValues()[i]) << " * " << m[i];
        }
      context << ");" << endCline;
    }
  }
}

/* Read/write access to the flag that selects "+=" (deltas) rather than
   "=" (absolute values) in emitStatement(). */
bool & UpdatePointBeforeStmtNode::valuesAreDeltas() { return _valuesAreDeltas; }

/* Look up vardecl in this loop's arrayAccesses; return the array() of
   the use that was found, or NULL when there is none. */
TreeNode *ForEachStmtNode::useOfArray(TreeNode *vardecl) const
{
  TreeNode *t = arrayAccesses->use(vardecl);
  if (t != NULL)
    t = t->array();
  return t;
}

/* Emit only the setup part of this loop, by temporarily swapping the
   body for a goto that leaves the loop at once. */
void ForEachStmtNode::emitSetup(CodeContext &context)
{
  static int unique_label;
  /* NOTE(review): loopNumber(), as defined earlier in this file, hands
     out numbers starting at 1, so this test looks like it can never
     hold and the counter is never reset here -- confirm the intent. */
  if (loopNumber(this) == 0) resetTempCounter(&unique_label);
  const string label = "setup_bailout" + int2string(unique_label++);

  /* Emit a foreach loop that, as its body, has a goto to the above label.
     Thus, we get all the strength reduction setup and so on, but
     do no iterations of the loop. */
  TreeNode *saved = stmt();
  stmt(new CodeLiteralNode("goto " + label + ';'));
  context << "/* foreach setup */" << endCline;
  emitStatement(context);
  context << "/* end foreach setup */" << endCline;
  context << label << ": ;" << endCline;
  /* Discard the temporary goto body and put the real body back. */
  delete stmt();
  stmt(saved);
}

/* For each SR pointer, generate an additional variable that is equal to
   the address to which that pointer would be set if the iteration point
   were 0.
*/ static void generateExtraSRsetup(ForEachStmtNode *l, CodeContext &context) { int ln = loopNumber(l), arity = l->tiArity(); if (mapLoopToBases != NULL) { set &s = (*mapLoopToBases)[ln]; for (set::iterator base = s.begin(); base != s.end(); base++) { const string at0 = atZero(*base), &q = (*mapLoopToQ)[ln]; PtrType &p = (*mapLoopBaseToPtrType)[ln][*base]; context.declare(at0, ctBox(p.first->cType(), p.second)); context << at0 << " = " << *base; map &m = (*mapLoopBaseDimToDiff)[ln][*base]; for (int i = 0; i < arity; i++) context << " - " << indexCurrent(q, i) << " * " << m[i]; context << ';' << endCline; } } } void ForEachSetupNode::emitStatement(CodeContext &context) { ForEachStmtNode *f = static_cast(WRTloop()); CodeContext *e = enclosingDeclContainer(this); CodeContext subcontext(context, *e); f->blockContext() = e; f->shouldDeclare() = f->saveSRinfo() = true; f->emitSetup(subcontext); generateExtraSRsetup(f, subcontext); f->saveSRinfo() = false; } void StrippedForEachNode::emitStatement(CodeContext &context) { ForEachStmtNode *f = static_cast(stmt()); CodeContext *e = enclosingDeclContainer(this); CodeContext subcontext(context, *e); f->blockContext() = e; f->shouldDeclare() = false; f->emitStatement(subcontext); } CodeContext *& ForEachStmtNode::blockContext() { return _blockContext; } void ForEachStmtNode::emitStatement(CodeContext &context) { string q = string("f") + int2string(loopNumber(this)); TreeNode *var = vars()->child(0)->simpName(); TreeNode *dom = vars()->child(0)->initExpr(); int a = tiArity(); bool isGeneralDomain = !dom->absent() && dom->type()->isDomainType(); reset_buildexpr(); context << endCline << "/* foreach */" << endCline; if (saveSRinfo()) saveSRq(this, q); { CodeContext subcontext(context, blockContext()); string curr_rect; if (isGeneralDomain) { foreachGeneralStart(subcontext, a, q, dom, shouldDeclare()); curr_rect = q + "_curr_rect"; } else { foreachRectDomainStart(subcontext, a, q, dom, this, curr_rect); } { CodeContext 
bodyContext(subcontext, blockContext()); if (stride() != NULL) foreachRangeIterate(bodyContext, a, q, lo(), hi(), stride(), var, stmt(), needPoint(), *SR, arrayAccesses, this, partialDomain(), shouldDeclare(), saveSRinfo()); else foreachRectIterate(bodyContext, a, q, dom, var, stmt(), needPoint(), curr_rect, cannotBeEmpty() == NULL, *SR, arrayAccesses, this, partialDomain(), shouldDeclare(), saveSRinfo()); /* If needed, declare the iteration point. */ if (needPoint()) bodyContext.declare(MANGLE_STACK_VAR(+, *(var->ident())), makePointType(a)->cType()); } if (isGeneralDomain) foreachGeneralEnd(subcontext, a, q); } context << "/* hcaerof */" << endCline; } TreeNode *ForEachStmtNode::cannotBeEmpty() const { return _cannotBeEmpty; } TreeNode *& ForEachStmtNode::cannotBeEmpty() { return _cannotBeEmpty; } int ForEachStmtNode::tiArity() const { return (stride() == NULL) ? vars()->child(0)->initExpr()->type()->tiArity() : 1; } bool ForEachStmtNode::partialDomain() const { return _partialDomain; } int & ForEachStmtNode::numLifted() { return _numLifted; } bool & ForEachStmtNode::needPoint() { return _needPoint; } bool & ForEachStmtNode::tentative() { return _tentative; } bool & ForEachStmtNode::ordered() { return _ordered; } ExprNode *& ForEachStmtNode::lo() { return _lo; } ExprNode *& ForEachStmtNode::hi() { return _hi; } ExprNode *& ForEachStmtNode::stride() { return _stride; } bool ForEachStmtNode::tentative() const { return _tentative; } bool ForEachStmtNode::ordered() const { return _ordered; } ExprNode *ForEachStmtNode::lo() const { return _lo; } ExprNode *ForEachStmtNode::hi() const { return _hi; } ExprNode *ForEachStmtNode::stride() const { return _stride; } void ForEachStmtNode::deepCloneSpecial (TreeNode* copy0) const { ForEachStmtNode* copy = (ForEachStmtNode*) copy0; if (_lo != NULL && ! _lo->absent ()) copy->_lo = static_cast(_lo->deepClone ()); if (_hi != NULL && ! _hi->absent ()) copy->_hi = static_cast(_hi->deepClone ()); if (_stride != NULL && ! 
_stride->absent ()) copy->_stride = static_cast(_stride->deepClone ()); } bool ForEachStmtNode::setParallel(bool b) { return ((_parallel = b)); } bool ForEachStmtNode::getParallel() const { return _parallel; } bool & ForEachStmtNode::saveSRinfo() { return _saveSRinfo; } bool & ForEachStmtNode::shouldDeclare() { return _shouldDeclare; } /* Some discussion of the non-integral increment problem is below. At the moment, we only try to optimize index expressions that must appear on every iteration of the loop, so it isn't an issue. Date: Fri, 30 May 1997 15:17:29 -0700 (PDT) From: Geoff Pike To: titanium-group@cs.Berkeley.EDU In-reply-to: <338E6139.6601283B@cs.berkeley.edu> (message from Alex Aiken on Thu, 29 May 1997 22:10:17 -0700) Subject: foreach, foreach, foreach Further analysis of foreach over RectDomains. There are two efficiency issues: bounds checks and non-integral increments. By deciding the bounds checking errors are ERRORS and deciding that for efficiency we may compile under the assumption that there are no out-of-bounds array references, much of the mess of two weeks ago is gone. What are the cases we can handle optimally? In short, the cases where every index expression A[e] that appears in the loop body is known to be valid in every iteration of the loop. Or A[e] simply appears in every iteration and is assumed to be valid because bounds checking is turned off. Or A[e] appears in every iteration and is proven to be valid in the loop header as part of bounds checking. Also, a few other cases. Notation: p is a point (the control point of the iteration) R and Q are RectDomains c, d, and k are loop invariant points or integers A and B are Titanium arrays (assume 1D for simplicity) 1a. foreach (p in R) { ... A[p] ... } where we know R <= A.domain (or foreach (p in Q) { ... A[c*p+d] ... 
} where c*Q+d <= A.domain)

     If you want the performance guarantee, one way to make sure you get
     it is to only use A[p] for a p running over a domain known to be a
     subset of A.domain.  Everything one can do with array foreach falls
     into this category, because "array foreach (f in A) ... f ..." is
     equivalent to "foreach (p in A.domain) ... A[p] ..."

     Why is it fast?  We know A[p] is valid for every p, hence it is safe
     and bounds checking is unnecessary.  Furthermore, &A[p] is integral
     if A[p] is valid, so the memory increment used in strength reduction
     is the difference of two valid addresses, an integer.

     This is the big winner.  It is trivial to have programmers
     understand it.  Also, my feeling is that 99% of loops will fall in
     this category or 1b.  HOWEVER, THE PROGRAMMER MAY HAVE TO NUDGE THE
     COMPILER TO MAKE SURE IT KNOWS WHAT IT NEEDS TO KNOW.

The things listed below are gravy (and increasingly obscure).

1b.  foreach (p in R) { ... A[c*p+d] ...}   // c and d are loop invariant
       where A[c*p+d] occurs in all iterations

     (In these examples we need not know a thing about R and A.)

     Example 1:  foreach (p in R) { A[p] = 7; }
     Example 2*: foreach (p in R) { if (foo()) A[p] = 7; else A[p+3] = 9; }
     Example 3:  foreach (p in R) { A[k*p]; if (foo()) A[k*p] = 7; }

     Intuition: Assume the program does not have bounds checking errors
     (i.e., bounds checking is turned off or we checked it somehow).  If
     A[p] is used in every iteration it is valid in every iteration and
     non-integral memory increments are unnecessary.

     Example 2* may be a problem.  Example 3 is OK because A[k*p] appears
     (and hence is assumed to be valid) in every iteration.  Note that the
     programmer might write this way just to be extra sure that the
     compiler sees &A[k*p] as an induction variable with integral
     increment.
     If bounds checking is turned off this is really just like case 1a,
     i.e., the compiler knows that the domain of access is a RectDomain
     and assumes it is a subset of the array's domain because it assumes
     no bounds checking errors.

2.   foreach (p in R) { ... A[p] ...}
       where (stride of R) times (memory stride of A) times
       (elt size of A) is an integer

     This is a subset of the dangerous cases that happen to be no problem
     because the granularity of pointers in the hardware is often 4x or
     8x different from the granularity of addresses of interest.

     Example 4:  in memory we have

           -----------------------------------------
       ... | A[0]    | A[2]    | A[4]    | A[6]    | ...  (memory stride = 1:2)
           -----------------------------------------

       int i = 0;
       foreach (p in [lo : hi]) { if (bar(i++)) A[p] = 11.0; }

     In this example, we can treat A[p] as an induction variable that is
     incremented by (memory stride) * (stride of [lo : hi], namely, 1) *
     (elt size) each iteration, which comes out to (elt size)/2, which is
     known at compile time to be an integer.

     Also, sometimes it may be possible to infer that an increment must
     be integral even without exact knowledge of the strides.  Compare
     this to example 3:

     Example 5:  foreach (p in R) { A[p] = sqrt(bar()); if (foo()) A[k*p] = 7; }

     Here we know that A[p] is always valid, so its increment is
     integral.  But A[k*p]'s increment is just k times A[p]'s increment,
     so it too is integral.  A similar case is where the memory stride of
     A is somehow known to be an integer, but we know not what it is.
     That's good enough, though.

-----------------------------------------------------------------------------

Now, the bad cases.

     Example 6*: foreach (p in R) { A[floor(p*sin(p[2]))] = 11; }

     This will compile into something at best equivalent to what C or
     FORTRAN would do, and probably worse because we won't be able to
     remove the division from the address calculation.
     Example 2*: foreach (p in R) { if (foo()) A[p] = 7; else A[p+3] = 9; }

     A[p] and A[p+3] can be treated as induction variables, but they may
     in general have non-integral increments.  This is the problem
     pointed out by David.  Let's simplify the code to just:

     Example 7*: foreach (p in R) { if (foo()) A[p] = 7; }

     and ask how to compile it.

     One solution is not to bother treating &A[p] as an induction
     variable.  Then we do the full address calculation (in general
     including some multiplies and adds and at least one divide) each
     time foo() is true.

     The other solutions treat &A[p] as an induction variable.  Then
     there are two relevant subcases:
        i.   We don't know whether the increment is integral
        ii.  We know it is not integral
       (iii. We know it is integral --- see "good cases" above)

     For the first case I propose we'll generate something like:

       {
         int index = ..., index_stop = ...;
         if (index < index_stop) {
           int index_inc = ..., w, inc, when_inc;
           if (...) { when_inc = ...; w = ...; inc = ...; }  // sad
           else { when_inc = 1; w = 0; inc = ...; }          // happy
           T *_a = ...;
           do {
             if (foo()) *_a = 7;
             index += index_inc;
             if (!(index < index_stop)) break;
             if (++w == when_inc) { w = 0; _a += inc; }
             // or, the last line could be replaced with:
             // if (happy || (++w == when_inc && ((w = 0), true))) _a += inc
           } while (true);
         }
       }

     In case (ii) we'll generate similar but slightly simpler code.

     By treating &A[p] as an induction variable we pay a few extra cycles
     per iteration instead of paying a bunch of extra cycles per use.

Geoff
*/