// Resolve operator overloading, rewrite as method invocations

#include "AST.h"
#include "decls.h"
#include "strings.h"

// Rewrite the operator expression `this` as a call of the method named `op`
// on `object`, passing `args`.
//
//   op     - operator name; it is interned and looked up as a method
//            (Decl::Method) in the environment of object's type.
//   assign - when true the result is a MethodCallAssignNode, otherwise a
//            MethodCallNode. The compound-assignment node classes
//            (BinaryArithAssignNode, ShiftAssignNode, BitwiseAssignNode)
//            pass true; everything else in this file passes false.
//   object - receiver expression of the operator.
//   args   - remaining operands, in order.
//
// Returns `this` unchanged when the (possibly swapped) receiver type is not
// a class/interface/immutable type, or when no method named `op` exists
// (the error is left to the type checker). Otherwise returns the newly
// built method-call node.
TreeNode *TreeNode::generalOverload(string op, bool assign,
                                    TreeNode *object, TreeListNode *args)
{
  TypeNode *otype = object->type();
  bool rewroteRHSOpOverload = false;

  // DOB: PR619 - operator overloading for "PrimitiveExpr X ClassExpr"
  // rewrite to: ClassExpr.opX(PrimitiveExpr, null)
  // (null gets replaced with a copy of the value resulting from ClassExpr during lowering)
  // The swap is recorded in rewroteRHSOpOverload and stored on the
  // resulting call node via isRewrittenRHSOpOverload().
  if (otype->isPrimitive() && args->arity() == 1 &&
      args->child(0)->type()->kind() & (ClassKind | InterfaceKind | ImmutableKind)) {
    TreeNode *primarg = object;
    object = args->child(0);
    otype = object->type();
    // Placeholder second argument, typed as the (new) receiver's type.
    ExprNode *nullarg = new NullPntrNode();
    nullarg->type(otype->deepClone());
    llist *tmp = cons(primarg, cons((TreeNode*)nullarg));
    args = new TreeListNode(tmp);
    rewroteRHSOpOverload = true;
  }

  // Only reference-like (class/interface/immutable) receivers can overload.
  if (!(otype->kind() & (ClassKind | InterfaceKind | ImmutableKind)))
    return this;

  const string *opString = intern(op);
  EnvironIter methods = otype->decl()->environ()
    ->lookupFirstProper(opString, Decl::Method);

  // Warn on ==/!= where the operand types are Domain/Domain,
  // RectDomain/Domain or Domain/RectDomain. This check runs before the
  // "no such method" early return below, so the warning is issued whether
  // or not an op==/op!= method was found.
  if ((*opString == "==" || *opString == "!=") && args->arity() == 1 &&
      ( (otype->isDomainType() && args->child(0)->type()->isDomainType()) ||
        (otype->isRectDomainType() && args->child(0)->type()->isDomainType()) ||
        (otype->isDomainType() && args->child(0)->type()->isRectDomainType()) )) {
    // Comparisons where one side is statically typed Object
    // (Object == Domain, Domain == Object) do not match the test above.
    // Users can therefore request a pointer compare without getting this warning.
    warning("domain-equals") << "Use of ==/!= on a Domain type now yields pointer (in)equality comparison - use Domain.equals() instead to compare Domain contents." << endl;
  }

  if (methods.isDone())
    return this; // leave error to type checker

  SourcePosn p = position();
  Decl *method = resolveCall(methods, otype->modifiers(), args, p);

  // build a method call node
  TreeNode *methodAccess =
    new ObjectFieldAccessNode(object,
                              new NameNode(TreeNode::omitted, opString, method, p),
                              p);
  TreeNode *methodNode;
  if (assign) {
    MethodCallAssignNode *tmp = new MethodCallAssignNode(methodAccess, NULL, p);
    tmp->isRewrittenRHSOpOverload(rewroteRHSOpOverload);
    methodNode = tmp;
  } else {
    MethodCallNode *tmp = new MethodCallNode(methodAccess, NULL, p);
    tmp->isRewrittenRHSOpOverload(rewroteRHSOpOverload);
    methodNode = tmp;
  }
  // Arguments are attached after construction (constructors receive NULL).
  methodNode->args(args);

  return methodNode;
}

// True if generalOverload swapped the operands of a
// "PrimitiveExpr X ClassExpr" expression when building this call.
bool MethodCallNode::isRewrittenRHSOpOverload() const {
  return _isRewrittenRHSOpOverload;
}

void MethodCallNode::isRewrittenRHSOpOverload(bool val) {
  _isRewrittenRHSOpOverload = val;
}

// Same flag as on MethodCallNode, for the assigning variant.
bool MethodCallAssignNode::isRewrittenRHSOpOverload() const {
  return _isRewrittenRHSOpOverload;
}

void MethodCallAssignNode::isRewrittenRHSOpOverload(bool val) {
  _isRewrittenRHSOpOverload = val;
}

// Overload a unary or binary operator node. child(0) is the receiver. For
// nodes with arity other than 1, child(1) becomes the single argument;
// unary nodes get an empty argument list.
TreeNode *TreeNode::resolveOverload(bool assign)
{
  llist *children = arity() == 1 ? NULL : cons(child(1));
  TreeListNode *methodArgs = new TreeListNode(children, position());

  return generalOverload(operatorName(), assign, child(0), methodArgs);
}

// Overload an array read: array()[index args] -> array().op[](index args).
TreeNode *TreeNode::overloadArray()
{
  return generalOverload("[]", false, array(), index()->args());
}

// Overload an array store: opnd0() is the array access being assigned and
// opnd1() the assigned value. The call receives the index arguments
// followed by the value: array.op[]=(index args..., value).
TreeNode *TreeNode::overloadArrayAssign()
{
  llist *mlArgs = appendTreeList(opnd0()->index()->args(), cons(opnd1()));
  TreeListNode *methodArgs = new TreeListNode(mlArgs, position());

  return generalOverload("[]=", false, opnd0()->array(), methodArgs);
}

// _resolveOperators: by default a node is left as-is. The operator node
// classes below delegate to resolveOverload. The compound-assignment
// classes pass assign = true.
TreeNode *TreeNode::_resolveOperators() { return this; }

TreeNode *ComplementNode::_resolveOperators() { return resolveOverload(false); }
TreeNode *NotNode::_resolveOperators() { return resolveOverload(false); }
TreeNode *UnaryArithNode::_resolveOperators() { return resolveOverload(false); }
TreeNode *BinaryArithNode::_resolveOperators() { return resolveOverload(false); }
TreeNode *PlusNode::_resolveOperators() { return resolveOverload( false ); }
TreeNode *ShiftNode::_resolveOperators() { return resolveOverload(false); }
TreeNode *RelationNode::_resolveOperators() { return resolveOverload(false); }
TreeNode *EqualityNode::_resolveOperators() { return resolveOverload(false); }
TreeNode *BitwiseNode::_resolveOperators() { return resolveOverload(false); }

TreeNode *BinaryArithAssignNode::_resolveOperators() { return resolveOverload(true); }
TreeNode *ShiftAssignNode::_resolveOperators() { return resolveOverload(true); }
TreeNode *BitwiseAssignNode::_resolveOperators() { return resolveOverload(true); }

// operatorName (this is used for more than just operator overloading, so
// it is defined on all operators)
// For overloadable operators the returned text is the method name that
// resolveOverload looks up. The base version calls undefined(), presumably
// an internal-error report for node kinds without a name (confirm in AST.h).
const char *TreeNode::operatorName() const { undefined("operatorName"); return NULL; }

const char *ArrayAccessNode::operatorName() const { return "[]"; }
const char *ComplementNode::operatorName() const { return "~"; }
const char *NotNode::operatorName() const { return "!"; }
const char *PostIncrNode::operatorName() const { return "++"; }
const char *PostDecrNode::operatorName() const { return "--"; }
const char *PreIncrNode::operatorName() const { return "++"; }
const char *PreDecrNode::operatorName() const { return "--"; }
const char *UnaryPlusNode::operatorName() const { return "+"; }
const char *UnaryMinusNode::operatorName() const { return "-"; }
const char *MultNode::operatorName() const { return "*"; }
const char *DivNode::operatorName() const { return "/"; }
const char *RemNode::operatorName() const { return "%"; }
const char *PlusNode::operatorName() const { return "+"; }
const char *MinusNode::operatorName() const { return "-"; }
const char *LeftShiftLogNode::operatorName() const { return "<<"; }
const char *RightShiftLogNode::operatorName() const { return ">>>"; }
const char *RightShiftArithNode::operatorName() const { return ">>"; }
const char *LTNode::operatorName() const { return "<"; }
const char *GTNode::operatorName() const { return ">"; }
const char *LENode::operatorName() const { return "<="; }
const char *GENode::operatorName() const { return ">="; }
const char *EQNode::operatorName() const { return "=="; }
const char *NENode::operatorName() const { return "!="; }
const char *BitAndNode::operatorName() const { return "&"; }
const char *BitOrNode::operatorName() const { return "|"; }
const char *BitXorNode::operatorName() const { return "^"; }
const char *CandNode::operatorName() const { return "&&"; }
const char *CorNode::operatorName() const { return "||"; }
const char *IfExprNode::operatorName() const { return "?:"; }
const char *AssignNode::operatorName() const { return "="; }
const char *PlusAssignNode::operatorName() const { return "+="; }
const char *MultAssignNode::operatorName() const { return "*="; }
const char *DivAssignNode::operatorName() const { return "/="; }
const char *RemAssignNode::operatorName() const { return "%="; }
const char *MinusAssignNode::operatorName() const { return "-="; }
const char *LeftShiftLogAssignNode::operatorName() const { return "<<="; }
const char *RightShiftLogAssignNode::operatorName() const { return ">>>="; }
const char *RightShiftArithAssignNode::operatorName() const { return ">>="; }
const char *BitAndAssignNode::operatorName() const { return "&="; }
const char *BitOrAssignNode::operatorName() const { return "|="; }
const char *BitXorAssignNode::operatorName() const { return "^="; }

// Descriptive names for non-overloadable expression forms.
const char *StringConcatNode::operatorName() const { return "string concatenation"; }
const char *StringConcatAssignNode::operatorName() const { return "assigned string concatenation"; }
const char *PointNode::operatorName() const { return "point expressions"; }
const char *DomainNode::operatorName() const { return "domain expressions"; }
const char *InstanceOfNode::operatorName() const { return "instanceof"; }
const char *CastNode::operatorName() const { return "cast"; }

// Operator names for statements (for error messages only)
const char *WhileNode::operatorName() const { return "while"; }
const char *DoNode::operatorName() const { return "do-while"; }
const char *ForNode::operatorName() const { return "for"; }
const char *SwitchNode::operatorName() const { return "switch"; }
const char *SwitchBranchNode::operatorName() const { return "switch-branch"; }
const char *ReturnNode::operatorName() const { return "return"; }
const char *AssertNode::operatorName() const { return "assert"; }
const char *IfStmtNode::operatorName() const { return "if"; }
const char *ThrowNode::operatorName() const { return "throw"; }
const char *SynchronizedNode::operatorName() const { return "synchronized"; }
const char *CatchNode::operatorName() const { return "try-catch"; }
const char *TryStmtNode::operatorName() const { return "try"; }
const char *ForEachStmtNode::operatorName() const { return "foreach"; }
const char *PartitionClauseNode::operatorName() const { return "partition clauses"; }
const char *BroadcastNode::operatorName() const { return "broadcast"; }