//===--- LiteralExpressionFolding.cpp - -------------------------*- C++ -*-===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // Simple AST-based evaluator of supported literal expressions // //===----------------------------------------------------------------------===// #include "LiteralExpressionFolding.h" #include "MiscDiagnostics.h" #include "swift/AST/ASTContext.h" #include "swift/AST/ASTWalker.h" #include "swift/AST/DiagnosticsSema.h" #include "swift/AST/TypeCheckRequests.h" #include "swift/Basic/Unreachable.h" using namespace swift; using namespace LiteralExprFolding; namespace { class FoldingError : public llvm::ErrorInfo { public: static char ID; IllegalConstError code; SourceLoc sourceLocation; FoldingError(IllegalConstError code) : code(code), sourceLocation() {} FoldingError(IllegalConstError code, SourceLoc loc) : code(code), sourceLocation(loc) {} void log(llvm::raw_ostream &OS) const override { OS << "Const Folding Error: " << static_cast::type>(code); } std::error_code convertToErrorCode() const override { return llvm::inconvertibleErrorCode(); } }; char FoldingError::ID = 0; static unsigned getTargetPointerBitWidth(ASTContext &ctx) { // Matching the compiler's determination of the // `#if _pointerBitWidth` value. if (ctx.LangOpts.Target.isArch16Bit()) return 16; if (ctx.LangOpts.Target.isArch32Bit()) return 32; if (ctx.LangOpts.Target.isArch64Bit()) return 64; swift_unreachable("Unsupported platform kind."); } /// Get the bitwidth for a Swift integer type for constant folding. /// Returns 0 if the type is not a known integer type. static unsigned getIntegerBitWidth(Type type, ASTContext &ctx) { assert(type->isStdlibInteger()); // Map stdlib integer types to their bitwidths if (type->isInt()) return getTargetPointerBitWidth(ctx); if (type->isInt128()) return 128; if (type->isInt64()) return 64; if (type->isInt32()) return 32; if (type->isInt16()) return 16; if (type->isInt8()) return 8; if (type->isUInt()) return getTargetPointerBitWidth(ctx); if (type->isUInt128()) return 128; if (type->isUInt64()) return 64; if (type->isUInt32()) return 32; if (type->isUInt16()) return 16; if (type->isUInt8()) return 8; swift_unreachable("Unsupported integer type."); } static ConcreteDeclRef getIntTypeBuiltinInit(Type type, ASTContext &ctx) { assert(type->isStdlibInteger()); // Map stdlib integer types to their bitwidths if (type->isInt()) return ctx.getIntBuiltinInitDecl(ctx.getIntDecl()); if (type->isInt128()) return ctx.getIntBuiltinInitDecl(ctx.getInt128Decl()); if (type->isInt64()) return ctx.getIntBuiltinInitDecl(ctx.getInt64Decl()); if (type->isInt32()) return ctx.getIntBuiltinInitDecl(ctx.getInt32Decl()); if (type->isInt16()) return ctx.getIntBuiltinInitDecl(ctx.getInt16Decl()); if (type->isInt8()) return ctx.getIntBuiltinInitDecl(ctx.getInt8Decl()); if (type->isUInt()) return ctx.getIntBuiltinInitDecl(ctx.getUIntDecl()); if (type->isUInt128()) return ctx.getIntBuiltinInitDecl(ctx.getUInt128Decl()); if (type->isUInt64()) return ctx.getIntBuiltinInitDecl(ctx.getUInt64Decl()); if (type->isUInt32()) return ctx.getIntBuiltinInitDecl(ctx.getUInt32Decl()); if (type->isUInt16()) return ctx.getIntBuiltinInitDecl(ctx.getUInt16Decl()); if (type->isUInt8()) return ctx.getIntBuiltinInitDecl(ctx.getUInt8Decl()); swift_unreachable("Unsupported integer type."); } /// Check if the type is a signed integer type static bool isSignedIntegerType(Type type) { return type->isInt() || type->isInt128() || type->isInt64() || type->isInt32() || type->isInt16() || type->isInt8(); } /// A move-only value that's either a `FoldingError` or some parameterized type /// value `T`. A convenience wrapper around `TaggedUnion`. template class FoldingErrorOr { TaggedUnion Value; public: FoldingErrorOr() : Value(FoldingError(IllegalConstError::Default)) {} FoldingErrorOr(T &&t) : Value(std::move(t)) {} FoldingErrorOr(const FoldingError &fe) : Value(fe) {} FoldingErrorOr(FoldingErrorOr &&other) : Value(std::move(other.Value)) {} FoldingErrorOr &operator=(FoldingErrorOr &&other) noexcept { if (this != &other) Value = std::move(other.Value); return *this; } FoldingErrorOr(const FoldingErrorOr &) = delete; FoldingErrorOr &operator=(const FoldingErrorOr &) = delete; const T *operator->() const { return Value.template dyn_cast(); } FoldingError getError() const { return *Value.template dyn_cast(); } bool isError() const { return Value.template dyn_cast() != nullptr; } /// Return false if there is an error. explicit operator bool() const { return !isError(); } }; class ConstantValue { public: enum class ConstantValueKind : int8_t { FirstKind, Integer = FirstKind, FloatingPoint, LastKind = FloatingPoint + 1 }; const ConstantValueKind kind; ConstantValue(ConstantValueKind kind) : kind(kind) {} virtual ~ConstantValue() = default; ConstantValueKind getKind() const { return kind; } }; using ConstantValuePtr = std::unique_ptr; class IntegerValue : public ConstantValue { APInt value; bool isSigned; public: IntegerValue(APInt value, bool isSigned) : ConstantValue(ConstantValueKind::Integer), value(value), isSigned(isSigned) {} APInt getValue() const { return value; } bool getIsSigned() const { return isSigned; } static bool classof(const ConstantValue *base) { return base->getKind() == ConstantValueKind::Integer; } }; /// A simple constant expression folder to simplify /// binary expressions of integer type consisting of literal /// value operands. class ConstantFolder { ASTContext &Ctx; public: ConstantFolder(ASTContext &ctx) : Ctx(ctx) {} Expr *fold(const Expr *expr) { // If this expression failed to type-check, no need to attempt to // fold it since we likely won't be able to do anything meaningful // here. if (!expr->getType() || expr->getType()->getAs()) { emitFoldingErrorDiagnostic( FoldingError(IllegalConstError::UpstreamError, expr->getStartLoc())); return nullptr; } ConstantWalker walker(Ctx); const_cast(expr)->walk(walker); ASSERT(walker.hasConstantValueFor(expr) && "No value or error computed by constant-folding AST walker"); const auto &result = walker.getConstantValueOrErrorFor(expr); if (result) return createIntegerLiteralExpr(expr, result->get()); else { emitFoldingErrorDiagnostic(result.getError()); return nullptr; } } private: class ConstantWalker : public ASTWalker { ASTContext &Ctx; llvm::DenseMap> ConstValuesOrErrors; public: ConstantWalker(ASTContext &ctx) : Ctx(ctx) {} PostWalkResult walkToExprPost(Expr *expr) override { ConstValuesOrErrors.insert({expr, tryFoldExpression(expr)}); return Action::Continue(expr); } bool hasConstantValueFor(const Expr *expr) { return ConstValuesOrErrors.contains(expr); } const FoldingErrorOr & getConstantValueOrErrorFor(const Expr *expr) { ASSERT(ConstValuesOrErrors.contains(expr) && "Querying constant value for an unfolded expression."); return ConstValuesOrErrors.at(expr); } private: FoldingErrorOr tryFoldExpression(const Expr *expr) { if (auto *literalExpr = dyn_cast(expr)) return tryFoldLiteralExpression(literalExpr); if (auto *binaryExpr = dyn_cast(expr)) return tryFoldBinaryExpr(binaryExpr); if (auto *unaryExpr = dyn_cast(expr)) return tryFoldUnaryExpr(unaryExpr); if (auto *parenExpr = dyn_cast(expr)) return tryFoldParenExpr(parenExpr); if (auto *declRefExpr = dyn_cast(expr)) return foldDeclRefExpr(declRefExpr); return FoldingError(IllegalConstError::Default, expr->getLoc()); } FoldingErrorOr tryFoldLiteralExpression(const LiteralExpr *expr) { if (auto *intLiteralExpr = dyn_cast(expr)) return foldIntegerLiteralExpr(intLiteralExpr); return FoldingError(IllegalConstError::Default, expr->getLoc()); } ConstantValuePtr foldIntegerLiteralExpr(const IntegerLiteralExpr *expr) { auto exprType = expr->getType(); auto value = expr->getValue(); auto resultBitWidth = getIntegerBitWidth(exprType, Ctx); if (isSignedIntegerType(exprType)) return std::make_unique(value.sextOrTrunc(resultBitWidth), true); return std::make_unique(value.zextOrTrunc(resultBitWidth), false); } FoldingErrorOr tryFoldBinaryExpr(const BinaryExpr *expr) { if (!expr->getType()->isStdlibInteger()) return FoldingError(IllegalConstError::TypeNotSupported, expr->getStartLoc()); if (!supportedOperator(expr)) return FoldingError(IllegalConstError::UnsupportedBinaryOperator, expr->getStartLoc()); if (const auto &lhsOrErr = getConstantValueOrErrorFor(expr->getLHS())) { if (const auto &rhsOrErr = getConstantValueOrErrorFor(expr->getRHS())) { auto lhsIntPtr = cast(lhsOrErr->get()); auto rhsIntPtr = cast(rhsOrErr->get()); auto operatorDecl = expr->getCalledValue(); auto operatorIdentifier = operatorDecl->getBaseName().getIdentifier(); if (operatorIdentifier.isArithmeticOperator()) return tryFoldIntegerBinaryArithmeticOperator( operatorIdentifier, expr->getLoc(), lhsIntPtr, rhsIntPtr); if (operatorIdentifier.isOverflowArithmeticOperator()) return tryFoldIntegerBinaryOverflowArithmeticOperator( operatorIdentifier, expr->getLoc(), lhsIntPtr, rhsIntPtr); if (operatorIdentifier.isBitwiseOperator() || operatorIdentifier.isShiftOperator()) return tryFoldIntegerBinaryBitwiseOperator( operatorIdentifier, expr->getLoc(), lhsIntPtr, rhsIntPtr); if (operatorIdentifier.isMaskingShiftOperator()) return tryFoldIntegerBinaryMaskingShiftOperator( operatorIdentifier, expr->getLoc(), lhsIntPtr, rhsIntPtr); llvm_unreachable("Unsupported operator"); } else return rhsOrErr.getError(); } else return lhsOrErr.getError(); } FoldingErrorOr tryFoldUnaryExpr(const PrefixUnaryExpr *expr) { if (!expr->getType()->isStdlibInteger()) return FoldingError(IllegalConstError::TypeNotSupported, expr->getLoc()); if (!supportedOperator(expr)) return FoldingError(IllegalConstError::UnsupportedBinaryOperator, expr->getLoc()); const auto &operandOrErr = getConstantValueOrErrorFor(expr->getOperand()); if (operandOrErr) { auto operatorIdentifier = expr->getCalledValue()->getBaseName().getIdentifier(); return tryFoldIntegerUnaryArithmeticOperator( operatorIdentifier, expr->getLoc(), cast(operandOrErr->get())); } return operandOrErr.getError(); } FoldingErrorOr tryFoldParenExpr(const ParenExpr *expr) { const auto &operandOrErr = getConstantValueOrErrorFor(expr->getSubExpr()); if (operandOrErr) { auto *intValue = cast(operandOrErr->get()); return ConstantValuePtr(std::make_unique( intValue->getValue(), intValue->getIsSigned())); } return operandOrErr.getError(); } FoldingErrorOr foldDeclRefExpr(const DeclRefExpr *expr) { if (const VarDecl *varDecl = dyn_cast(expr->getDecl())) { // Swift source `let` bindings whose access level is broader than // internal participate in the ABI surface of their module and may // not appear in a literal expression. Emit the diagnostic inline so // we can carry the access level, and return `UpstreamError` so no // further generic "not a literal expression" message is added. // A `var` that reaches here falls through to the existing // opaque-decl-ref path, which is the correct diagnostic for it. if (!varDecl->hasClangNode() && varDecl->isLet()) { auto access = varDecl->getFormalAccess(); if (access >= AccessLevel::Package) { Ctx.Diags.diagnose(expr->getLoc(), diag::const_public_let_ref, access); return FoldingError(IllegalConstError::UpstreamError, expr->getLoc()); } } // For other `@const` or `@section` values, we expect // their initializer to be foldable. For other values which // have a default value, we attempt to fold the // corresponding initializer expression. if (varDecl->isConstValue() || varDecl->hasInitialValue()) if (auto initExpr = varDecl->getParentInitializer()) return tryFoldDeclRefInitializerExpr(initExpr, expr->getLoc()); // Clang constants are imported as a ValueDecl // with simple getter returning a literal value. if (varDecl->hasClangNode() && !varDecl->isDynamic() && !varDecl->isObjC() && varDecl->getImplInfo().getReadImpl() == ReadImplKind::Get) if (auto accessor = varDecl->getAccessor(AccessorKind::Get)) if (auto singleRetStmt = dyn_cast( accessor->getBody()->getSingleActiveStatement())) return tryFoldDeclRefInitializerExpr(singleRetStmt->getResult(), expr->getLoc()); } return FoldingError(IllegalConstError::OpaqueDeclRef, expr->getLoc()); } FoldingErrorOr tryFoldDeclRefInitializerExpr(const Expr *expr, SourceLoc referenceLoc) { bool previouslyFolded = Ctx.evaluator.hasCachedResult(ConstantFoldExpression{expr, &Ctx}); // Request the init expression of this declaration to be // constant-folded. if (auto foldedLiteralExpr = dyn_cast(swift::foldLiteralExpression(expr, &Ctx))) return tryFoldLiteralExpression(foldedLiteralExpr); // If this is the first time we have requested to constant-fold this // declaration's initializer and have failed to do so, emit a note // with a location of the declRef from which we initiated this query. if (!previouslyFolded) return FoldingError(IllegalConstError::NonConstDeclRef, referenceLoc); return FoldingError(IllegalConstError::OpaqueDeclRef, referenceLoc); } FoldingErrorOr tryFoldIntegerBinaryArithmeticOperator( Identifier operatorIdentifier, SourceLoc sourceLocation, const IntegerValue *lhsVal, const IntegerValue *rhsVal) { assert(lhsVal->getIsSigned() == rhsVal->getIsSigned()); bool isSigned = lhsVal->getIsSigned(); auto lhsInt = lhsVal->getValue(); auto rhsInt = rhsVal->getValue(); APInt result; bool overflow = false; if (operatorIdentifier.is("+")) result = isSigned ? lhsInt.sadd_ov(rhsInt, overflow) : lhsInt.uadd_ov(rhsInt, overflow); else if (operatorIdentifier.is("-")) result = isSigned ? lhsInt.ssub_ov(rhsInt, overflow) : lhsInt.usub_ov(rhsInt, overflow); else if (operatorIdentifier.is("*")) result = isSigned ? lhsInt.smul_ov(rhsInt, overflow) : lhsInt.umul_ov(rhsInt, overflow); else if (operatorIdentifier.is("/")) { if (rhsInt == 0) return FoldingError(IllegalConstError::DivideByZero, sourceLocation); result = isSigned ? lhsInt.sdiv_ov(rhsInt, overflow) : lhsInt.udiv(rhsInt); } else if (operatorIdentifier.is("%")) { if (rhsInt == 0) return FoldingError(IllegalConstError::DivideByZero, sourceLocation); if (isSigned) // Check for overflow auto divResult = lhsInt.sdiv_ov(rhsInt, overflow); result = isSigned ? lhsInt.srem(rhsInt) : lhsInt.urem(rhsInt); } else return FoldingError(IllegalConstError::UnsupportedBinaryOperator, sourceLocation); if (overflow) return FoldingError(IllegalConstError::IntegerOverflow, sourceLocation); return ConstantValuePtr(std::make_unique(result, isSigned)); } FoldingErrorOr tryFoldIntegerBinaryOverflowArithmeticOperator( Identifier operatorIdentifier, SourceLoc sourceLocation, const IntegerValue *lhsVal, const IntegerValue *rhsVal) { assert(lhsVal->getIsSigned() == rhsVal->getIsSigned()); bool isSigned = lhsVal->getIsSigned(); auto lhsInt = lhsVal->getValue(); auto rhsInt = rhsVal->getValue(); // APInt's plain +/-/* wrap at the declared bit width, which matches // Swift's overflow arithmetic semantics for &+, &-, &*. APInt result; if (operatorIdentifier.is("&+")) result = lhsInt + rhsInt; else if (operatorIdentifier.is("&-")) result = lhsInt - rhsInt; else if (operatorIdentifier.is("&*")) result = lhsInt * rhsInt; else return FoldingError(IllegalConstError::UnsupportedBinaryOperator, sourceLocation); return ConstantValuePtr(std::make_unique(result, isSigned)); } FoldingErrorOr tryFoldIntegerBinaryBitwiseOperator( Identifier operatorIdentifier, SourceLoc sourceLocation, const IntegerValue *lhsVal, const IntegerValue *rhsVal) { bool isSigned = lhsVal->getIsSigned(); auto lhsInt = lhsVal->getValue(); auto rhsInt = rhsVal->getValue(); // The stdlib shift operators take `where RHS : BinaryInteger`, so the // RHS does not share signedness or bit width with the LHS. The true // bitwise operators `& | ^` do operate on two `Self` values. bool isShift = operatorIdentifier.is("<<") || operatorIdentifier.is(">>"); assert(isShift || lhsVal->getIsSigned() == rhsVal->getIsSigned()); APInt result; if (operatorIdentifier.is("&")) result = lhsInt & rhsInt; else if (operatorIdentifier.is("|")) result = lhsInt | rhsInt; else if (operatorIdentifier.is("^")) result = lhsInt ^ rhsInt; else if (isShift) { // Non-masking shifts trap at runtime when the amount is negative or // greater-or-equal to the LHS bit width. Reject both as folding // errors so the result matches Swift's runtime semantics. The // dedicated diagnostics are emitted inline so they can carry the // amount and bit width; return `UpstreamError` to suppress the // generic "not a literal expression" follow-up. if (rhsVal->getIsSigned() && rhsInt.isNegative()) { Ctx.Diags.diagnose(sourceLocation, diag::const_shift_negative); return FoldingError(IllegalConstError::UpstreamError, sourceLocation); } unsigned width = lhsInt.getBitWidth(); uint64_t amountValue = rhsInt.getLimitedValue(); if (amountValue >= width) { Ctx.Diags.diagnose(sourceLocation, diag::const_shift_out_of_range, static_cast(amountValue), width); return FoldingError(IllegalConstError::UpstreamError, sourceLocation); } unsigned amount = static_cast(amountValue); if (operatorIdentifier.is("<<")) result = lhsInt.shl(amount); else result = isSigned ? lhsInt.ashr(amount) : lhsInt.lshr(amount); } else return FoldingError(IllegalConstError::UnsupportedBinaryOperator, sourceLocation); return ConstantValuePtr(std::make_unique(result, isSigned)); } FoldingErrorOr tryFoldIntegerBinaryMaskingShiftOperator( Identifier operatorIdentifier, SourceLoc sourceLocation, const IntegerValue *lhsVal, const IntegerValue *rhsVal) { // The stdlib masking-shift operators take `where RHS : BinaryInteger`, // so the RHS does not share signedness or bit width with the LHS. bool isSigned = lhsVal->getIsSigned(); auto lhsInt = lhsVal->getValue(); auto rhsInt = rhsVal->getValue(); // Per FixedWidthInteger.&<< / &>>, the shift amount is reduced modulo // the result type's bit width. All Swift integer widths are powers of // two, so urem(width) is equivalent to & (width - 1). unsigned width = lhsInt.getBitWidth(); APInt amount = rhsInt.urem(APInt(rhsInt.getBitWidth(), width)); APInt result; if (operatorIdentifier.is("&<<")) result = lhsInt.shl(amount); else if (operatorIdentifier.is("&>>")) result = isSigned ? lhsInt.ashr(amount) : lhsInt.lshr(amount); else return FoldingError(IllegalConstError::UnsupportedBinaryOperator, sourceLocation); return ConstantValuePtr(std::make_unique(result, isSigned)); } FoldingErrorOr tryFoldIntegerUnaryArithmeticOperator(Identifier operatorIdentifier, SourceLoc sourceLocation, const IntegerValue *operandVal) { APInt result; auto operand = operandVal->getValue(); bool overflow = false; if (operatorIdentifier.is("-")) { APInt zero = APInt(operand.getBitWidth(), 0); result = operandVal->getIsSigned() ? zero.ssub_ov(operand, overflow) : zero.usub_ov(operand, overflow); } else if (operatorIdentifier.is("+")) result = operand; else if (operatorIdentifier.is("~")) result = ~operand; else return FoldingError(IllegalConstError::UnsupportedUnaryOperator, sourceLocation); if (overflow) return FoldingError(IllegalConstError::IntegerOverflow, sourceLocation); return ConstantValuePtr( std::make_unique(result, operandVal->getIsSigned())); } }; Expr *createIntegerLiteralExpr(const Expr *foldedExpr, const ConstantValue *result) { assert(isa(result)); auto intResult = cast(result)->getValue(); auto resultType = foldedExpr->getType(); assert(resultType->isStdlibInteger()); bool isSigned = isSignedIntegerType(resultType); SmallString<32> resultStr; // Get the absolute value for a signed integer // because it is represented as a 'negative value' on the resulting // `IntegerLiteralExpr`. if (isSigned) intResult.abs().toString(resultStr, 10, true); else intResult.toString(resultStr, 10, false); auto *newLit = new (Ctx) IntegerLiteralExpr( Ctx.getIdentifier(resultStr).str(), foldedExpr->getLoc(), /*implicit*/ true); newLit->setType(resultType); newLit->setImplicit(); newLit->setBuiltinInitializer(getIntTypeBuiltinInit(resultType, Ctx)); if (isSigned && intResult.slt(0)) newLit->setNegative(foldedExpr->getLoc()); return newLit; } void emitFoldingErrorDiagnostic(const FoldingError &foldingError) { diagnoseError(foldingError.sourceLocation, foldingError.code, Ctx.Diags); } }; } // anonymous namespace Expr *swift::foldLiteralExpression(const Expr *expr, ASTContext *ctx) { return evaluateOrDefault(ctx->evaluator, ConstantFoldExpression{expr, ctx}, {}); } Expr *ConstantFoldExpression::evaluate(Evaluator &evaluator, const Expr *expr, ASTContext *ctx) const { if (ctx->LangOpts.hasFeature(Feature::LiteralExpressions)) { ConstantFolder folder(*ctx); if (auto result = folder.fold(expr)) return result; } return const_cast(expr); }