//===--- DerivedConformanceComparable.cpp - Derived Comparable -===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements implicit derivation of the Comparable protocol. // (Most of this code is similar to code in `DerivedConformanceEquatableHashable.cpp`) // //===----------------------------------------------------------------------===// #include "TypeChecker.h" #include "swift/AST/Decl.h" #include "swift/AST/Stmt.h" #include "swift/AST/Expr.h" #include "swift/AST/Module.h" #include "swift/AST/Pattern.h" #include "swift/AST/ParameterList.h" #include "swift/AST/ProtocolConformance.h" #include "swift/AST/Types.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/raw_ostream.h" #include "DerivedConformances.h" using namespace swift; static std::pair deriveBodyComparable_enum_uninhabited_lt(AbstractFunctionDecl *ltDecl, void *) { auto parentDC = ltDecl->getDeclContext(); ASTContext &C = parentDC->getASTContext(); auto args = ltDecl->getParameters(); auto aParam = args->get(0); auto bParam = args->get(1); assert(!cast(aParam->getType()->getAnyNominal())->hasCases()); assert(!cast(bParam->getType()->getAnyNominal())->hasCases()); SmallVector statements; auto body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc()); return { body, /*isTypeChecked=*/true }; } /// Derive the body for a '<' operator for an enum that has no associated /// values. This generates code that converts each value to its integer ordinal /// and compares them, which produces an optimal single icmp instruction. 
static std::pair deriveBodyComparable_enum_noAssociatedValues_lt(AbstractFunctionDecl *ltDecl, void *) {
  auto parentDC = ltDecl->getDeclContext();
  ASTContext &C = parentDC->getASTContext();

  // The operator's two operands, `a` and `b`, both of the enum type.
  auto args = ltDecl->getParameters();
  auto aParam = args->get(0);
  auto bParam = args->get(1);

  auto enumDecl = cast(aParam->getType()->getAnyNominal());

  // Generate the conversion from the enums to integer indices.
  // Each call returns a reference to a local ("index_a" / "index_b") holding
  // the operand's case index; presumably it appends the statements that
  // compute that local to `statements` — confirm in DerivedConformances.
  SmallVector statements;
  DeclRefExpr *aIndex = DerivedConformance::convertEnumToIndex(statements, parentDC, enumDecl, aParam, ltDecl, "index_a");
  DeclRefExpr *bIndex = DerivedConformance::convertEnumToIndex(statements, parentDC, enumDecl, bParam, ltDecl, "index_b");

  // Generate the compare of the indices.
  // deriveComparable_lt diagnoses and bails out when Int's `<` is missing, so
  // the lookup is expected to succeed by the time this body is synthesized.
  FuncDecl *cmpFunc = C.getLessThanIntDecl();
  assert(cmpFunc && "should have a < for int as we already checked for it");
  auto fnType = cmpFunc->getInterfaceType()->castTo();

  Expr *cmpFuncExpr;
  if (cmpFunc->getDeclContext()->isTypeContext()) {
    // `<` is declared as a member operator: its interface type is curried over
    // the containing type. Apply the decl reference to a type expression for
    // that type, and narrow `fnType` to the inner (operand) function type.
    auto contextTy = cmpFunc->getDeclContext()->getSelfInterfaceType();
    Expr *base = TypeExpr::createImplicitHack(SourceLoc(), contextTy, C);
    Expr *ref = new (C) DeclRefExpr(cmpFunc, DeclNameLoc(), /*Implicit*/ true, AccessSemantics::Ordinary, fnType);
    fnType = fnType->getResult()->castTo();
    cmpFuncExpr = new (C) DotSyntaxCallExpr(ref, SourceLoc(), base, fnType);
    cmpFuncExpr->setImplicit();
  } else {
    // `<` is a free operator function: reference it directly.
    cmpFuncExpr = new (C) DeclRefExpr(cmpFunc, DeclNameLoc(), /*implicit*/ true, AccessSemantics::Ordinary, fnType);
  }

  // Argument tuple `(index_a, index_b)` for the operator call.
  TupleTypeElt abTupleElts[2] = { aIndex->getType(), bIndex->getType() };
  TupleExpr *abTuple = TupleExpr::create(C, SourceLoc(), { aIndex, bIndex }, { }, { }, SourceLoc(), /*HasTrailingClosure*/ false, /*Implicit*/ true, TupleType::get(abTupleElts, C));

  // return index_a < index_b
  auto *cmpExpr = new (C) BinaryExpr( cmpFuncExpr, abTuple, /*implicit*/ true, fnType->castTo()->getResult());
  statements.push_back(new (C) ReturnStmt(SourceLoc(), cmpExpr));

  // The expressions built here are given explicit types, and the body is
  // reported to the caller as already type-checked.
  BraceStmt *body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc());
  return { body,
    /*isTypeChecked=*/true };
}

/// Derive the body for a '<' operator for an enum where at least one of the
/// cases has associated values.
///
/// The generated body is roughly:
///
///   switch (a, b) {
///   case (.x(let l0, ...), .x(let r0, ...)):   // one per enum element
///     <guard per payload pair, returning its comparison if unequal>
///     return false                            // all payloads equal
///   ...
///   default:                                  // only if > 1 element
///     <compare the case indices of a and b>
///   }
static std::pair deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, void *) {
  auto parentDC = ltDecl->getDeclContext();
  ASTContext &C = parentDC->getASTContext();

  auto args = ltDecl->getParameters();
  auto aParam = args->get(0);
  auto bParam = args->get(1);

  Type enumType = aParam->getType();
  auto enumDecl = cast(aParam->getType()->getAnyNominal());

  SmallVector statements;
  SmallVector cases;
  unsigned elementCount = 0; // need this as `getAllElements` returns a generator

  // For each enum element, generate a case statement matching a pair containing
  // the same case, binding variables for the left- and right-hand associated
  // values.
  for (auto elt : enumDecl->getAllElements()) {
    elementCount++;

    // .<elt>(let l0, let l1, ...)
    SmallVector lhsPayloadVars;
    auto lhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'l', ltDecl, lhsPayloadVars);
    auto lhsElemPat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType), SourceLoc(), DeclNameLoc(), DeclNameRef(), elt, lhsSubpattern);
    lhsElemPat->setImplicit();

    // .<elt>(let r0, let r1, ...)
    SmallVector rhsPayloadVars;
    auto rhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'r', ltDecl, rhsPayloadVars);
    auto rhsElemPat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType), SourceLoc(), DeclNameLoc(), DeclNameRef(), elt, rhsSubpattern);
    rhsElemPat->setImplicit();

    auto hasBoundDecls = !lhsPayloadVars.empty();
    Optional> caseBodyVarDecls;
    if (hasBoundDecls) {
      // Allocate a fresh copy of the lhs payload var decls (same introducer,
      // name and context) to serve as the case body's variables, which are
      // handed to CaseStmt::create below.
      auto copy = C.Allocate(lhsPayloadVars.size());
      for (unsigned i : indices(lhsPayloadVars)) {
        auto *vOld = lhsPayloadVars[i];
        auto *vNew = new (C) VarDecl( /*IsStatic*/ false, vOld->getIntroducer(), false /*IsCaptureList*/, vOld->getNameLoc(), vOld->getName(), vOld->getDeclContext());
        vNew->setHasNonPatternBindingInit();
        vNew->setImplicit();
        copy[i] = vNew;
      }
      caseBodyVarDecls.emplace(copy);
    }

    // case (.<elt>(let l0, let l1, ...), .<elt>(let r0, let r1, ...))
    auto caseTuplePattern = TuplePattern::create(C, SourceLoc(), { TuplePatternElt(lhsElemPat), TuplePatternElt(rhsElemPat) }, SourceLoc());
    caseTuplePattern->setImplicit();

    auto labelItem = CaseLabelItem(caseTuplePattern);

    // Generate a guard statement for each associated value in the payload,
    // exiting early at the first pair that is unequal (as in the Equatable
    // synthesis). The guard's else branch returns the comparison of that
    // pair, which yields a lexicographic order over the payload values.
    SmallVector statementsInCase;
    for (size_t varIdx = 0; varIdx < lhsPayloadVars.size(); varIdx++) {
      auto lhsVar = lhsPayloadVars[varIdx];
      auto lhsExpr = new (C) DeclRefExpr(lhsVar, DeclNameLoc(), /*implicit*/true);
      auto rhsVar = rhsPayloadVars[varIdx];
      auto rhsExpr = new (C) DeclRefExpr(rhsVar, DeclNameLoc(), /*Implicit*/true);
      auto guardStmt = DerivedConformance::returnComparisonIfNotEqualGuard(C, lhsExpr, rhsExpr);
      statementsInCase.emplace_back(guardStmt);
    }

    // If none of the guard statements caused an early exit, every payload
    // pair compared equal, so `a` is not less than `b`:
    // return false
    auto falseExpr = new (C) BooleanLiteralExpr(false, SourceLoc(), /*Implicit*/true);
    auto returnStmt = new (C) ReturnStmt(SourceLoc(), falseExpr);
    statementsInCase.push_back(returnStmt);

    auto body = BraceStmt::create(C, SourceLoc(), statementsInCase, SourceLoc());
    cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), labelItem, SourceLoc(), SourceLoc(), body, caseBodyVarDecls));
  }

  // default: return index(a) < index(b)
  //
  // The operands are of different cases here, so their order is decided by
  // case index, reusing the body built for payload-free enums (its
  // isTypeChecked flag is dropped via `.first`).
  //
  // We only generate this if the enum has more than one case. If it has exactly
  // one case, then that single case statement is already exhaustive.
  if (elementCount > 1) {
    auto defaultPattern = new (C) AnyPattern(SourceLoc());
    defaultPattern->setImplicit();
    auto defaultItem = CaseLabelItem::getDefault(defaultPattern);
    auto body = deriveBodyComparable_enum_noAssociatedValues_lt(ltDecl, nullptr).first;
    cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), defaultItem, SourceLoc(), SourceLoc(), body, /*case body var decls*/ None));
  }

  // switch (a, b) { <cases> }
  auto aRef = new (C) DeclRefExpr(aParam, DeclNameLoc(), /*implicit*/true);
  auto bRef = new (C) DeclRefExpr(bParam, DeclNameLoc(), /*implicit*/true);
  auto abExpr = TupleExpr::create(C, SourceLoc(), { aRef, bRef }, {}, {}, SourceLoc(), /*HasTrailingClosure*/ false, /*implicit*/ true);
  auto switchStmt = SwitchStmt::create(LabeledStmtInfo(), SourceLoc(), abExpr, SourceLoc(), cases, SourceLoc(), C);
  statements.push_back(switchStmt);

  // Several expressions above (e.g. aRef/bRef and the payload references) are
  // created without types, so the whole body is left for the type checker.
  auto body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc());
  return { body, /*isTypeChecked=*/false };
}

/// Derive an '<' operator implementation for an enum.
static ValueDecl * deriveComparable_lt( DerivedConformance &derived, std::pair (*bodySynthesizer)(AbstractFunctionDecl *, void *)) { ASTContext &C = derived.Context; auto parentDC = derived.getConformanceContext(); auto selfIfaceTy = parentDC->getDeclaredInterfaceType(); auto getParamDecl = [&](StringRef s) -> ParamDecl * { auto *param = new (C) ParamDecl(SourceLoc(), SourceLoc(), Identifier(), SourceLoc(), C.getIdentifier(s), parentDC); param->setSpecifier(ParamSpecifier::Default); param->setInterfaceType(selfIfaceTy); return param; }; ParameterList *params = ParameterList::create(C, { getParamDecl("a"), getParamDecl("b") }); auto boolTy = C.getBoolDecl()->getDeclaredType(); Identifier generatedIdentifier; if (parentDC->getParentModule()->isResilient()) { generatedIdentifier = C.Id_LessThanOperator; } else { assert(selfIfaceTy->getEnumOrBoundGenericEnum()); generatedIdentifier = C.Id_derived_enum_less_than; } DeclName name(C, generatedIdentifier, params); auto comparableDecl = FuncDecl::create(C, /*StaticLoc=*/SourceLoc(), StaticSpellingKind::KeywordStatic, /*FuncLoc=*/SourceLoc(), name, /*NameLoc=*/SourceLoc(), /*Throws=*/false, /*ThrowsLoc=*/SourceLoc(), /*GenericParams=*/nullptr, params, TypeLoc::withoutLoc(boolTy), parentDC); comparableDecl->setImplicit(); comparableDecl->setUserAccessible(false); // Add the @_implements(Comparable, < (_:_:)) attribute if (generatedIdentifier != C.Id_LessThanOperator) { auto comparable = C.getProtocol(KnownProtocolKind::Comparable); auto comparableType = comparable->getDeclaredType(); auto comparableTypeLoc = TypeLoc::withoutLoc(comparableType); SmallVector argumentLabels = { Identifier(), Identifier() }; auto comparableDeclName = DeclName(C, DeclBaseName(C.Id_LessThanOperator), argumentLabels); comparableDecl->getAttrs().add(new (C) ImplementsAttr(SourceLoc(), SourceRange(), comparableTypeLoc, comparableDeclName, DeclNameLoc())); } if (!C.getLessThanIntDecl()) { 
derived.ConformanceDecl->diagnose(diag::no_less_than_overload_for_int); return nullptr; } comparableDecl->setBodySynthesizer(bodySynthesizer); comparableDecl->copyFormalAccessFrom(derived.Nominal, /*sourceIsParentContext*/ true); // Add the operator to the parent scope. derived.addMembersToConformanceContext({comparableDecl}); return comparableDecl; } // for now, only enums can synthesize `Comparable`, so this function can take // an `EnumDecl` instead of a `NominalTypeDecl` bool DerivedConformance::canDeriveComparable(DeclContext *context, EnumDecl *enumeration) { // The type must be an enum. if (!enumeration) { return false; } auto comparable = context->getASTContext().getProtocol(KnownProtocolKind::Comparable); if (!comparable) { return false; // not sure what should be done here instead } // The cases must not have non-comparable associated values or raw backing return allAssociatedValuesConformToProtocol(context, enumeration, comparable) && !enumeration->hasRawType(); } ValueDecl *DerivedConformance::deriveComparable(ValueDecl *requirement) { if (checkAndDiagnoseDisallowedContext(requirement)) { return nullptr; } if (requirement->getBaseName() != "<") { requirement->diagnose(diag::broken_comparable_requirement); return nullptr; } // Build the necessary decl. auto enumeration = dyn_cast(this->Nominal); assert(enumeration); std::pair (*synthesizer)(AbstractFunctionDecl *, void *); if (enumeration->hasCases()) { if (enumeration->hasOnlyCasesWithoutAssociatedValues()) { synthesizer = &deriveBodyComparable_enum_noAssociatedValues_lt; } else { synthesizer = &deriveBodyComparable_enum_hasAssociatedValues_lt; } } else { synthesizer = &deriveBodyComparable_enum_uninhabited_lt; } return deriveComparable_lt(*this, synthesizer); }