mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
341 lines
14 KiB
C++
341 lines
14 KiB
C++
//===--- DerivedConformanceComparable.cpp - Derived Comparable -===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements implicit derivation of the Comparable protocol.
|
|
// (Most of this code is similar to code in `DerivedConformanceEquatableHashable.cpp`)
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TypeChecker.h"
|
|
#include "swift/AST/Decl.h"
|
|
#include "swift/AST/Stmt.h"
|
|
#include "swift/AST/Expr.h"
|
|
#include "swift/AST/Module.h"
|
|
#include "swift/AST/Pattern.h"
|
|
#include "swift/AST/ParameterList.h"
|
|
#include "swift/AST/ProtocolConformance.h"
|
|
#include "swift/AST/Types.h"
|
|
#include "llvm/ADT/APInt.h"
|
|
#include "llvm/ADT/SmallString.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "DerivedConformances.h"
|
|
|
|
using namespace swift;
|
|
|
|
static std::pair<BraceStmt *, bool>
|
|
deriveBodyComparable_enum_uninhabited_lt(AbstractFunctionDecl *ltDecl, void *) {
|
|
auto parentDC = ltDecl->getDeclContext();
|
|
ASTContext &C = parentDC->getASTContext();
|
|
|
|
auto args = ltDecl->getParameters();
|
|
auto aParam = args->get(0);
|
|
auto bParam = args->get(1);
|
|
|
|
assert(!cast<EnumDecl>(aParam->getType()->getAnyNominal())->hasCases());
|
|
assert(!cast<EnumDecl>(bParam->getType()->getAnyNominal())->hasCases());
|
|
|
|
SmallVector<ASTNode, 0> statements;
|
|
auto body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc());
|
|
return { body, /*isTypeChecked=*/true };
|
|
}
|
|
|
|
/// Derive the body for a '<' operator for an enum that has no associated
|
|
/// values. This generates code that converts each value to its integer ordinal
|
|
/// and compares them, which produces an optimal single icmp instruction.
|
|
static std::pair<BraceStmt *, bool>
|
|
deriveBodyComparable_enum_noAssociatedValues_lt(AbstractFunctionDecl *ltDecl,
|
|
void *) {
|
|
auto parentDC = ltDecl->getDeclContext();
|
|
ASTContext &C = parentDC->getASTContext();
|
|
|
|
auto args = ltDecl->getParameters();
|
|
auto aParam = args->get(0);
|
|
auto bParam = args->get(1);
|
|
|
|
auto enumDecl = cast<EnumDecl>(aParam->getType()->getAnyNominal());
|
|
|
|
// Generate the conversion from the enums to integer indices.
|
|
SmallVector<ASTNode, 8> statements;
|
|
DeclRefExpr *aIndex = DerivedConformance::convertEnumToIndex(statements, parentDC, enumDecl,
|
|
aParam, ltDecl, "index_a");
|
|
DeclRefExpr *bIndex = DerivedConformance::convertEnumToIndex(statements, parentDC, enumDecl,
|
|
bParam, ltDecl, "index_b");
|
|
|
|
// Generate the compare of the indices.
|
|
FuncDecl *cmpFunc = C.getLessThanIntDecl();
|
|
assert(cmpFunc && "should have a < for int as we already checked for it");
|
|
|
|
Expr *cmpFuncExpr = new (C) DeclRefExpr(cmpFunc, DeclNameLoc(),
|
|
/*implicit*/ true,
|
|
AccessSemantics::Ordinary);
|
|
|
|
TupleExpr *abTuple = TupleExpr::create(C, SourceLoc(), { aIndex, bIndex },
|
|
{ }, { }, SourceLoc(),
|
|
/*HasTrailingClosure*/ false,
|
|
/*Implicit*/ true);
|
|
|
|
auto *cmpExpr = new (C) BinaryExpr(
|
|
cmpFuncExpr, abTuple, /*implicit*/ true);
|
|
statements.push_back(new (C) ReturnStmt(SourceLoc(), cmpExpr));
|
|
|
|
BraceStmt *body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc());
|
|
return { body, /*isTypeChecked=*/false };
|
|
}
|
|
|
|
/// Derive the body for an '==' operator for an enum where at least one of the
|
|
/// cases has associated values.
|
|
static std::pair<BraceStmt *, bool>
|
|
deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, void *) {
|
|
auto parentDC = ltDecl->getDeclContext();
|
|
ASTContext &C = parentDC->getASTContext();
|
|
|
|
auto args = ltDecl->getParameters();
|
|
auto aParam = args->get(0);
|
|
auto bParam = args->get(1);
|
|
|
|
Type enumType = aParam->getType();
|
|
auto enumDecl = cast<EnumDecl>(aParam->getType()->getAnyNominal());
|
|
|
|
SmallVector<ASTNode, 8> statements;
|
|
SmallVector<ASTNode, 4> cases;
|
|
unsigned elementCount = 0; // need this as `getAllElements` returns a generator
|
|
|
|
// For each enum element, generate a case statement matching a pair containing
|
|
// the same case, binding variables for the left- and right-hand associated
|
|
// values.
|
|
for (auto elt : enumDecl->getAllElements()) {
|
|
++elementCount;
|
|
|
|
// .<elt>(let l0, let l1, ...)
|
|
SmallVector<VarDecl*, 4> lhsPayloadVars;
|
|
auto lhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'l', ltDecl,
|
|
lhsPayloadVars);
|
|
auto *lhsBaseTE = TypeExpr::createImplicit(enumType, C);
|
|
auto lhsElemPat =
|
|
new (C) EnumElementPattern(lhsBaseTE, SourceLoc(), DeclNameLoc(),
|
|
DeclNameRef(), elt, lhsSubpattern);
|
|
lhsElemPat->setImplicit();
|
|
|
|
// .<elt>(let r0, let r1, ...)
|
|
SmallVector<VarDecl*, 4> rhsPayloadVars;
|
|
auto rhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'r', ltDecl,
|
|
rhsPayloadVars);
|
|
auto *rhsBaseTE = TypeExpr::createImplicit(enumType, C);
|
|
auto rhsElemPat =
|
|
new (C) EnumElementPattern(rhsBaseTE, SourceLoc(), DeclNameLoc(),
|
|
DeclNameRef(), elt, rhsSubpattern);
|
|
rhsElemPat->setImplicit();
|
|
|
|
auto hasBoundDecls = !lhsPayloadVars.empty();
|
|
Optional<MutableArrayRef<VarDecl *>> caseBodyVarDecls;
|
|
if (hasBoundDecls) {
|
|
// We allocated a direct copy of our lhs var decls for the case
|
|
// body.
|
|
auto copy = C.Allocate<VarDecl *>(lhsPayloadVars.size());
|
|
for (unsigned i : indices(lhsPayloadVars)) {
|
|
auto *vOld = lhsPayloadVars[i];
|
|
auto *vNew = new (C) VarDecl(
|
|
/*IsStatic*/ false, vOld->getIntroducer(), false /*IsCaptureList*/,
|
|
vOld->getNameLoc(), vOld->getName(), vOld->getDeclContext());
|
|
vNew->setHasNonPatternBindingInit();
|
|
vNew->setImplicit();
|
|
copy[i] = vNew;
|
|
}
|
|
caseBodyVarDecls.emplace(copy);
|
|
}
|
|
|
|
// case (.<elt>(let l0, let l1, ...), .<elt>(let r0, let r1, ...))
|
|
auto caseTuplePattern = TuplePattern::createImplicit(C, {
|
|
TuplePatternElt(lhsElemPat), TuplePatternElt(rhsElemPat) });
|
|
caseTuplePattern->setImplicit();
|
|
|
|
auto labelItem = CaseLabelItem(caseTuplePattern);
|
|
|
|
// Generate a guard statement for each associated value in the payload,
|
|
// breaking out early if any pair is unequal. (same as Equatable synthesis.)
|
|
// the else statement performs the lexicographic comparison.
|
|
SmallVector<ASTNode, 8> statementsInCase;
|
|
for (size_t varIdx = 0; varIdx < lhsPayloadVars.size(); ++varIdx) {
|
|
auto lhsVar = lhsPayloadVars[varIdx];
|
|
auto lhsExpr = new (C) DeclRefExpr(lhsVar, DeclNameLoc(),
|
|
/*implicit*/true);
|
|
auto rhsVar = rhsPayloadVars[varIdx];
|
|
auto rhsExpr = new (C) DeclRefExpr(rhsVar, DeclNameLoc(),
|
|
/*Implicit*/true);
|
|
auto guardStmt = DerivedConformance::returnComparisonIfNotEqualGuard(C,
|
|
lhsExpr, rhsExpr);
|
|
statementsInCase.emplace_back(guardStmt);
|
|
}
|
|
|
|
// If none of the guard statements caused an early exit, then all the pairs
|
|
// were true. (equal)
|
|
// return false
|
|
auto falseExpr = new (C) BooleanLiteralExpr(false, SourceLoc(),
|
|
/*Implicit*/true);
|
|
auto returnStmt = new (C) ReturnStmt(SourceLoc(), falseExpr);
|
|
statementsInCase.push_back(returnStmt);
|
|
|
|
auto body = BraceStmt::create(C, SourceLoc(), statementsInCase,
|
|
SourceLoc());
|
|
cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(),
|
|
labelItem, SourceLoc(), SourceLoc(), body,
|
|
caseBodyVarDecls));
|
|
}
|
|
|
|
// default: result = <enum index>(lhs) < <enum index>(rhs)
|
|
//
|
|
// We only generate this if the enum has more than one case. If it has exactly
|
|
// one case, then that single case statement is already exhaustive.
|
|
if (elementCount > 1) {
|
|
auto defaultPattern = AnyPattern::createImplicit(C);
|
|
auto defaultItem = CaseLabelItem::getDefault(defaultPattern);
|
|
auto body = deriveBodyComparable_enum_noAssociatedValues_lt(ltDecl, nullptr).first;
|
|
cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(),
|
|
defaultItem, SourceLoc(), SourceLoc(),
|
|
body,
|
|
/*case body var decls*/ None));
|
|
}
|
|
|
|
// switch (a, b) { <case statements> }
|
|
auto aRef = new (C) DeclRefExpr(aParam, DeclNameLoc(), /*implicit*/true);
|
|
auto bRef = new (C) DeclRefExpr(bParam, DeclNameLoc(), /*implicit*/true);
|
|
auto abExpr = TupleExpr::create(C, SourceLoc(), { aRef, bRef }, {}, {},
|
|
SourceLoc(), /*HasTrailingClosure*/ false,
|
|
/*implicit*/ true);
|
|
auto switchStmt = SwitchStmt::create(LabeledStmtInfo(), SourceLoc(), abExpr,
|
|
SourceLoc(), cases, SourceLoc(), C);
|
|
statements.push_back(switchStmt);
|
|
|
|
auto body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc());
|
|
return { body, /*isTypeChecked=*/false };
|
|
}
|
|
|
|
/// Derive an '<' operator implementation for an enum.
|
|
static ValueDecl *
|
|
deriveComparable_lt(
|
|
DerivedConformance &derived,
|
|
std::pair<BraceStmt *, bool> (*bodySynthesizer)(AbstractFunctionDecl *,
|
|
void *)) {
|
|
ASTContext &C = derived.Context;
|
|
|
|
auto parentDC = derived.getConformanceContext();
|
|
auto selfIfaceTy = parentDC->getDeclaredInterfaceType();
|
|
|
|
auto getParamDecl = [&](StringRef s) -> ParamDecl * {
|
|
auto *param = new (C) ParamDecl(SourceLoc(),
|
|
SourceLoc(), Identifier(), SourceLoc(),
|
|
C.getIdentifier(s), parentDC);
|
|
param->setSpecifier(ParamSpecifier::Default);
|
|
param->setInterfaceType(selfIfaceTy);
|
|
return param;
|
|
};
|
|
|
|
ParameterList *params = ParameterList::create(C, {
|
|
getParamDecl("a"),
|
|
getParamDecl("b")
|
|
});
|
|
|
|
auto boolTy = C.getBoolDecl()->getDeclaredType();
|
|
|
|
Identifier generatedIdentifier;
|
|
if (parentDC->getParentModule()->isResilient()) {
|
|
generatedIdentifier = C.Id_LessThanOperator;
|
|
} else {
|
|
assert(selfIfaceTy->getEnumOrBoundGenericEnum());
|
|
generatedIdentifier = C.Id_derived_enum_less_than;
|
|
}
|
|
|
|
DeclName name(C, generatedIdentifier, params);
|
|
auto comparableDecl =
|
|
FuncDecl::create(C, /*StaticLoc=*/SourceLoc(),
|
|
StaticSpellingKind::KeywordStatic,
|
|
/*FuncLoc=*/SourceLoc(), name, /*NameLoc=*/SourceLoc(),
|
|
/*Throws=*/false, /*ThrowsLoc=*/SourceLoc(),
|
|
/*GenericParams=*/nullptr,
|
|
params,
|
|
TypeLoc::withoutLoc(boolTy),
|
|
parentDC);
|
|
comparableDecl->setImplicit();
|
|
comparableDecl->setUserAccessible(false);
|
|
|
|
// Add the @_implements(Comparable, < (_:_:)) attribute
|
|
if (generatedIdentifier != C.Id_LessThanOperator) {
|
|
auto comparable = C.getProtocol(KnownProtocolKind::Comparable);
|
|
auto comparableType = comparable->getDeclaredType();
|
|
auto comparableTypeExpr = TypeExpr::createImplicit(comparableType, C);
|
|
SmallVector<Identifier, 2> argumentLabels = { Identifier(), Identifier() };
|
|
auto comparableDeclName = DeclName(C, DeclBaseName(C.Id_LessThanOperator),
|
|
argumentLabels);
|
|
comparableDecl->getAttrs().add(new (C) ImplementsAttr(SourceLoc(),
|
|
SourceRange(),
|
|
comparableTypeExpr,
|
|
comparableDeclName,
|
|
DeclNameLoc()));
|
|
}
|
|
|
|
if (!C.getLessThanIntDecl()) {
|
|
derived.ConformanceDecl->diagnose(diag::no_less_than_overload_for_int);
|
|
return nullptr;
|
|
}
|
|
|
|
comparableDecl->setBodySynthesizer(bodySynthesizer);
|
|
|
|
comparableDecl->copyFormalAccessFrom(derived.Nominal, /*sourceIsParentContext*/ true);
|
|
|
|
// Add the operator to the parent scope.
|
|
derived.addMembersToConformanceContext({comparableDecl});
|
|
|
|
return comparableDecl;
|
|
}
|
|
|
|
// for now, only enums can synthesize `Comparable`, so this function can take
|
|
// an `EnumDecl` instead of a `NominalTypeDecl`
|
|
bool
|
|
DerivedConformance::canDeriveComparable(DeclContext *context, EnumDecl *enumeration) {
|
|
// The type must be an enum.
|
|
if (!enumeration) {
|
|
return false;
|
|
}
|
|
auto comparable = context->getASTContext().getProtocol(KnownProtocolKind::Comparable);
|
|
if (!comparable) {
|
|
return false; // not sure what should be done here instead
|
|
}
|
|
// The cases must not have non-comparable associated values or raw backing
|
|
return allAssociatedValuesConformToProtocol(context, enumeration, comparable) && !enumeration->hasRawType();
|
|
}
|
|
|
|
/// Derives the implementation of Comparable's '<' requirement for an enum.
///
/// Returns nullptr if checkAndDiagnoseDisallowedContext rejects the
/// requirement. It also returns nullptr, after diagnosing
/// `broken_comparable_requirement`, if the requirement is not '<'.
/// Otherwise it chooses a body synthesizer from the enum's shape and
/// delegates to deriveComparable_lt:
///  - no cases                          -> uninhabited body
///  - cases without associated values   -> case-index comparison
///  - some case has associated values   -> per-case lexicographic switch
ValueDecl *DerivedConformance::deriveComparable(ValueDecl *requirement) {
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;

  if (requirement->getBaseName() != "<") {
    requirement->diagnose(diag::broken_comparable_requirement);
    return nullptr;
  }

  // Only enums reach this point.
  auto *enumDecl = dyn_cast<EnumDecl>(Nominal);
  assert(enumDecl);

  using BodySynthesizerFn =
      std::pair<BraceStmt *, bool> (*)(AbstractFunctionDecl *, void *);

  BodySynthesizerFn bodyFn = &deriveBodyComparable_enum_uninhabited_lt;
  if (enumDecl->hasCases()) {
    bodyFn = enumDecl->hasOnlyCasesWithoutAssociatedValues()
                 ? &deriveBodyComparable_enum_noAssociatedValues_lt
                 : &deriveBodyComparable_enum_hasAssociatedValues_lt;
  }

  return deriveComparable_lt(*this, bodyFn);
}
|