Files
swift-mirror/lib/Sema/DerivedConformanceComparable.cpp
Owen Voorhees 43e2d107e1 [SE-0276] Implement multi-pattern catch clauses
Like switch cases, a catch clause may now include a comma-
separated list of patterns. The body will be executed if any
one of those patterns is matched.

This patch replaces `CatchStmt` with `CaseStmt` as the children
of `DoCatchStmt` in the AST. This necessitates a number of changes
throughout the compiler, including:
- Parser & libsyntax support for the new syntax and AST structure
- Typechecking of multi-pattern catches, including those which
  contain bindings.
- SILGen support
- Code completion updates
- Profiler updates
- Name lookup changes
2020-04-04 09:28:26 -07:00

361 lines
15 KiB
C++

//===--- DerivedConformanceComparable.cpp - Derived Comparable -===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements implicit derivation of the Comparable protocol.
// (Most of this code is similar to code in `DerivedConformanceEquatableHashable.cpp`)
//
//===----------------------------------------------------------------------===//
#include "TypeChecker.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Expr.h"
#include "swift/AST/Module.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Types.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/raw_ostream.h"
#include "DerivedConformances.h"
using namespace swift;
/// Synthesize the body of '<' for an enum that declares no cases.
///
/// Both parameters are asserted to be of a case-less enum type, and the
/// resulting body is an empty brace statement. The `void *` argument is
/// unused.
///
/// \returns the empty body paired with `true` (already type checked).
static std::pair<BraceStmt *, bool>
deriveBodyComparable_enum_uninhabited_lt(AbstractFunctionDecl *ltDecl, void *) {
  auto *declContext = ltDecl->getDeclContext();
  ASTContext &ctx = declContext->getASTContext();

  auto *params = ltDecl->getParameters();
  auto *lhsParam = params->get(0);
  auto *rhsParam = params->get(1);

  // Sanity check: this synthesizer is only meant for enums without cases.
  assert(!cast<EnumDecl>(lhsParam->getType()->getAnyNominal())->hasCases());
  assert(!cast<EnumDecl>(rhsParam->getType()->getAnyNominal())->hasCases());

  // The body intentionally contains no statements.
  SmallVector<ASTNode, 0> noStatements;
  auto *emptyBody =
      BraceStmt::create(ctx, SourceLoc(), noStatements, SourceLoc());
  return { emptyBody, /*isTypeChecked=*/true };
}
/// Synthesize the body of '<' for an enum whose cases carry no associated
/// values.
///
/// Each operand is first converted to its integer index via
/// `DerivedConformance::convertEnumToIndex`; the body then returns the result
/// of applying Int's '<' to the two indices. The `void *` argument is unused.
///
/// \returns the synthesized body paired with `true` (already type checked).
static std::pair<BraceStmt *, bool>
deriveBodyComparable_enum_noAssociatedValues_lt(AbstractFunctionDecl *ltDecl,
                                                void *) {
  auto *declContext = ltDecl->getDeclContext();
  ASTContext &ctx = declContext->getASTContext();

  auto *params = ltDecl->getParameters();
  auto *lhsParam = params->get(0);
  auto *rhsParam = params->get(1);
  auto *theEnum = cast<EnumDecl>(lhsParam->getType()->getAnyNominal());

  // Emit the enum-to-index conversions; the helper appends its statements to
  // `bodyStmts` and hands back a reference to the resulting index variable.
  SmallVector<ASTNode, 8> bodyStmts;
  DeclRefExpr *lhsIndex = DerivedConformance::convertEnumToIndex(
      bodyStmts, declContext, theEnum, lhsParam, ltDecl, "index_a");
  DeclRefExpr *rhsIndex = DerivedConformance::convertEnumToIndex(
      bodyStmts, declContext, theEnum, rhsParam, ltDecl, "index_b");

  // Look up '<' on Int and build an expression referring to it.
  FuncDecl *lessThanInt = ctx.getLessThanIntDecl();
  assert(lessThanInt && "should have a < for int as we already checked for it");
  auto calleeType = lessThanInt->getInterfaceType()->castTo<FunctionType>();

  Expr *callee;
  if (!lessThanInt->getDeclContext()->isTypeContext()) {
    // Free function: a plain reference to the declaration is enough.
    callee = new (ctx) DeclRefExpr(lessThanInt, DeclNameLoc(),
                                   /*implicit*/ true,
                                   AccessSemantics::Ordinary,
                                   calleeType);
  } else {
    // Member of a type: apply the reference to the type's metatype first, and
    // continue with the function type that results from that application.
    auto selfTy = lessThanInt->getDeclContext()->getSelfInterfaceType();
    Expr *metatypeBase = TypeExpr::createImplicitHack(SourceLoc(), selfTy, ctx);
    Expr *memberRef = new (ctx) DeclRefExpr(lessThanInt, DeclNameLoc(),
                                            /*Implicit*/ true,
                                            AccessSemantics::Ordinary,
                                            calleeType);
    calleeType = calleeType->getResult()->castTo<FunctionType>();
    callee = new (ctx) DotSyntaxCallExpr(memberRef, SourceLoc(), metatypeBase,
                                         calleeType);
    callee->setImplicit();
  }

  // Package the two indices as the argument tuple `(index_a, index_b)`.
  TupleTypeElt operandTypes[2] = { lhsIndex->getType(), rhsIndex->getType() };
  auto operandTupleTy = TupleType::get(operandTypes, ctx);
  TupleExpr *operands = TupleExpr::create(ctx, SourceLoc(),
                                          { lhsIndex, rhsIndex },
                                          { }, { }, SourceLoc(),
                                          /*HasTrailingClosure*/ false,
                                          /*Implicit*/ true,
                                          operandTupleTy);

  // return index_a < index_b
  auto resultTy = calleeType->castTo<FunctionType>()->getResult();
  auto *lessThanCall = new (ctx) BinaryExpr(callee, operands,
                                            /*implicit*/ true, resultTy);
  bodyStmts.push_back(new (ctx) ReturnStmt(SourceLoc(), lessThanCall));

  BraceStmt *body =
      BraceStmt::create(ctx, SourceLoc(), bodyStmts, SourceLoc());
  return { body, /*isTypeChecked=*/true };
}
/// Derive the body for a '<' operator for an enum where at least one of the
/// cases has associated values.
///
/// The generated body is a single `switch (a, b)` containing one case per enum
/// element that matches the same element on both sides and, when the enum has
/// more than one element, a `default` case whose body is the index comparison
/// produced by `deriveBodyComparable_enum_noAssociatedValues_lt`.
/// The `void *` parameter is unused.
///
/// \returns the synthesized body paired with `false`, i.e. the body is
/// reported as not yet type checked.
static std::pair<BraceStmt *, bool>
deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, void *) {
auto parentDC = ltDecl->getDeclContext();
ASTContext &C = parentDC->getASTContext();
auto args = ltDecl->getParameters();
auto aParam = args->get(0);
auto bParam = args->get(1);
Type enumType = aParam->getType();
auto enumDecl = cast<EnumDecl>(aParam->getType()->getAnyNominal());
// `statements` becomes the function body; `cases` collects the switch cases.
SmallVector<ASTNode, 8> statements;
SmallVector<ASTNode, 4> cases;
unsigned elementCount = 0; // need this as `getAllElements` returns a generator
// For each enum element, generate a case statement matching a pair containing
// the same case, binding variables for the left- and right-hand associated
// values.
for (auto elt : enumDecl->getAllElements()) {
elementCount++;
// .<elt>(let l0, let l1, ...)
// The helper also fills `lhsPayloadVars` with the variables it binds.
SmallVector<VarDecl*, 4> lhsPayloadVars;
auto lhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'l', ltDecl,
lhsPayloadVars);
auto lhsElemPat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
SourceLoc(), DeclNameLoc(),
DeclNameRef(), elt,
lhsSubpattern);
lhsElemPat->setImplicit();
// .<elt>(let r0, let r1, ...)
// Same as above for the right-hand operand, using the 'r' prefix.
SmallVector<VarDecl*, 4> rhsPayloadVars;
auto rhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'r', ltDecl,
rhsPayloadVars);
auto rhsElemPat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
SourceLoc(), DeclNameLoc(),
DeclNameRef(), elt,
rhsSubpattern);
rhsElemPat->setImplicit();
// The case only needs body variable declarations when the element actually
// has a payload; otherwise `caseBodyVarDecls` stays empty (None).
auto hasBoundDecls = !lhsPayloadVars.empty();
Optional<MutableArrayRef<VarDecl *>> caseBodyVarDecls;
if (hasBoundDecls) {
// We allocate a direct copy of our lhs var decls for the case
// body. Each copy reuses the original's introducer, name, name location
// and decl context, and is marked implicit.
auto copy = C.Allocate<VarDecl *>(lhsPayloadVars.size());
for (unsigned i : indices(lhsPayloadVars)) {
auto *vOld = lhsPayloadVars[i];
auto *vNew = new (C) VarDecl(
/*IsStatic*/ false, vOld->getIntroducer(), false /*IsCaptureList*/,
vOld->getNameLoc(), vOld->getName(), vOld->getDeclContext());
vNew->setHasNonPatternBindingInit();
vNew->setImplicit();
copy[i] = vNew;
}
caseBodyVarDecls.emplace(copy);
}
// case (.<elt>(let l0, let l1, ...), .<elt>(let r0, let r1, ...))
auto caseTuplePattern = TuplePattern::create(C, SourceLoc(), {
TuplePatternElt(lhsElemPat), TuplePatternElt(rhsElemPat) },
SourceLoc());
caseTuplePattern->setImplicit();
auto labelItem = CaseLabelItem(caseTuplePattern);
// Generate a guard statement for each associated value in the payload,
// breaking out early if any pair is unequal (same as Equatable synthesis).
// The else branch of each guard performs the lexicographic comparison; the
// guard itself is built by DerivedConformance::returnComparisonIfNotEqualGuard.
SmallVector<ASTNode, 8> statementsInCase;
for (size_t varIdx = 0; varIdx < lhsPayloadVars.size(); varIdx++) {
auto lhsVar = lhsPayloadVars[varIdx];
auto lhsExpr = new (C) DeclRefExpr(lhsVar, DeclNameLoc(),
/*implicit*/true);
auto rhsVar = rhsPayloadVars[varIdx];
auto rhsExpr = new (C) DeclRefExpr(rhsVar, DeclNameLoc(),
/*Implicit*/true);
auto guardStmt = DerivedConformance::returnComparisonIfNotEqualGuard(C,
lhsExpr, rhsExpr);
statementsInCase.emplace_back(guardStmt);
}
// If none of the guard statements caused an early exit, then every pair of
// associated values compared equal, so `a < b` is false:
// return false
auto falseExpr = new (C) BooleanLiteralExpr(false, SourceLoc(),
/*Implicit*/true);
auto returnStmt = new (C) ReturnStmt(SourceLoc(), falseExpr);
statementsInCase.push_back(returnStmt);
auto body = BraceStmt::create(C, SourceLoc(), statementsInCase,
SourceLoc());
cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(),
labelItem, SourceLoc(), SourceLoc(), body,
caseBodyVarDecls));
}
// default: result = <enum index>(lhs) < <enum index>(rhs)
//
// We only generate this if the enum has more than one case. If it has exactly
// one case, then that single case statement is already exhaustive.
// The default body reuses the synthesizer for enums without associated
// values; only the body (`.first`) of its result is used here.
if (elementCount > 1) {
auto defaultPattern = new (C) AnyPattern(SourceLoc());
defaultPattern->setImplicit();
auto defaultItem = CaseLabelItem::getDefault(defaultPattern);
auto body = deriveBodyComparable_enum_noAssociatedValues_lt(ltDecl, nullptr).first;
cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(),
defaultItem, SourceLoc(), SourceLoc(),
body,
/*case body var decls*/ None));
}
// switch (a, b) { <case statements> }
auto aRef = new (C) DeclRefExpr(aParam, DeclNameLoc(), /*implicit*/true);
auto bRef = new (C) DeclRefExpr(bParam, DeclNameLoc(), /*implicit*/true);
auto abExpr = TupleExpr::create(C, SourceLoc(), { aRef, bRef }, {}, {},
SourceLoc(), /*HasTrailingClosure*/ false,
/*implicit*/ true);
auto switchStmt = SwitchStmt::create(LabeledStmtInfo(), SourceLoc(), abExpr,
SourceLoc(), cases, SourceLoc(), C);
statements.push_back(switchStmt);
auto body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc());
// Unlike the other synthesizers in this file, this body is returned as not
// yet type checked.
return { body, /*isTypeChecked=*/false };
}
/// Derive a '<' operator implementation for an enum.
///
/// Creates an implicit, non-user-accessible static function taking two
/// parameters `a` and `b` of the conforming type and returning `Bool`,
/// installs \p bodySynthesizer as its body synthesizer, and adds it to the
/// conformance context.
///
/// In a resilient module the function is named `<` directly; otherwise it is
/// given the derived-enum-less-than identifier and tied to `Comparable.<` via
/// an `@_implements` attribute.
///
/// \param derived the conformance being derived.
/// \param bodySynthesizer callback that produces the function body.
/// \returns the new declaration, or nullptr (after emitting a diagnostic on
/// the conformance declaration) when no '<' overload for Int is available.
static ValueDecl *
deriveComparable_lt(
    DerivedConformance &derived,
    std::pair<BraceStmt *, bool> (*bodySynthesizer)(AbstractFunctionDecl *,
                                                    void *)) {
  ASTContext &C = derived.Context;

  // Check the precondition up front so that no AST nodes are allocated (and
  // then thrown away) when derivation cannot succeed.
  if (!C.getLessThanIntDecl()) {
    derived.ConformanceDecl->diagnose(diag::no_less_than_overload_for_int);
    return nullptr;
  }

  auto parentDC = derived.getConformanceContext();
  auto selfIfaceTy = parentDC->getDeclaredInterfaceType();

  // Builds an unlabeled parameter named `s` whose type is the conforming type.
  auto getParamDecl = [&](StringRef s) -> ParamDecl * {
    auto *param = new (C) ParamDecl(SourceLoc(),
                                    SourceLoc(), Identifier(), SourceLoc(),
                                    C.getIdentifier(s), parentDC);
    param->setSpecifier(ParamSpecifier::Default);
    param->setInterfaceType(selfIfaceTy);
    return param;
  };

  ParameterList *params = ParameterList::create(C, {
    getParamDecl("a"),
    getParamDecl("b")
  });

  auto boolTy = C.getBoolDecl()->getDeclaredType();

  // Pick the base name of the synthesized function.
  Identifier generatedIdentifier;
  if (parentDC->getParentModule()->isResilient()) {
    generatedIdentifier = C.Id_LessThanOperator;
  } else {
    assert(selfIfaceTy->getEnumOrBoundGenericEnum());
    generatedIdentifier = C.Id_derived_enum_less_than;
  }

  DeclName name(C, generatedIdentifier, params);
  auto comparableDecl =
      FuncDecl::create(C, /*StaticLoc=*/SourceLoc(),
                       StaticSpellingKind::KeywordStatic,
                       /*FuncLoc=*/SourceLoc(), name, /*NameLoc=*/SourceLoc(),
                       /*Throws=*/false, /*ThrowsLoc=*/SourceLoc(),
                       /*GenericParams=*/nullptr,
                       params,
                       TypeLoc::withoutLoc(boolTy),
                       parentDC);
  comparableDecl->setImplicit();
  comparableDecl->setUserAccessible(false);

  // Add the @_implements(Comparable, <(_:_:)) attribute when the function was
  // not named `<` itself, so it still witnesses the protocol requirement.
  if (generatedIdentifier != C.Id_LessThanOperator) {
    auto comparable = C.getProtocol(KnownProtocolKind::Comparable);
    auto comparableType = comparable->getDeclaredType();
    auto comparableTypeLoc = TypeLoc::withoutLoc(comparableType);
    SmallVector<Identifier, 2> argumentLabels = { Identifier(), Identifier() };
    auto comparableDeclName = DeclName(C, DeclBaseName(C.Id_LessThanOperator),
                                       argumentLabels);
    comparableDecl->getAttrs().add(new (C) ImplementsAttr(SourceLoc(),
                                                          SourceRange(),
                                                          comparableTypeLoc,
                                                          comparableDeclName,
                                                          DeclNameLoc()));
  }

  comparableDecl->setBodySynthesizer(bodySynthesizer);
  comparableDecl->copyFormalAccessFrom(derived.Nominal,
                                       /*sourceIsParentContext*/ true);

  // Add the operator to the parent scope.
  derived.addMembersToConformanceContext({comparableDecl});
  return comparableDecl;
}
/// Decide whether `Comparable` can be synthesized for \p enumeration in
/// \p context.
///
/// Only enums are supported at the moment, which is why this takes an
/// `EnumDecl` rather than a `NominalTypeDecl`. Derivation is allowed when every
/// associated value conforms to `Comparable` and the enum has no raw type.
bool DerivedConformance::canDeriveComparable(DeclContext *context,
                                             EnumDecl *enumeration) {
  // A null declaration means the type is not an enum.
  if (enumeration == nullptr)
    return false;

  // Without a Comparable protocol declaration there is nothing to check
  // conformance against, so derivation is refused.
  auto *comparableProto =
      context->getASTContext().getProtocol(KnownProtocolKind::Comparable);
  if (comparableProto == nullptr)
    return false;

  // Every associated value must itself be Comparable...
  if (!allAssociatedValuesConformToProtocol(context, enumeration,
                                            comparableProto))
    return false;

  // ...and the enum must not be backed by a raw type.
  return !enumeration->hasRawType();
}
/// Derive the witness for a `Comparable` requirement on the nominal type of
/// this conformance.
///
/// Only the `<` requirement can be derived; any other requirement is diagnosed
/// as a broken `Comparable` requirement.
///
/// \param requirement the protocol requirement to synthesize a witness for.
/// \returns the synthesized '<' declaration, or nullptr if the context is
/// disallowed, the requirement is not `<`, or synthesis fails.
ValueDecl *DerivedConformance::deriveComparable(ValueDecl *requirement) {
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;

  if (requirement->getBaseName() != "<") {
    requirement->diagnose(diag::broken_comparable_requirement);
    return nullptr;
  }

  // The nominal type is required to be an enum here, so use the asserting
  // cast<> rather than dyn_cast<> followed by a separate assert.
  auto *enumeration = cast<EnumDecl>(this->Nominal);

  // Pick the body synthesizer that matches the shape of the enum.
  std::pair<BraceStmt *, bool> (*synthesizer)(AbstractFunctionDecl *, void *);
  if (!enumeration->hasCases()) {
    synthesizer = &deriveBodyComparable_enum_uninhabited_lt;
  } else if (enumeration->hasOnlyCasesWithoutAssociatedValues()) {
    synthesizer = &deriveBodyComparable_enum_noAssociatedValues_lt;
  } else {
    synthesizer = &deriveBodyComparable_enum_hasAssociatedValues_lt;
  }

  // Build the necessary decl.
  return deriveComparable_lt(*this, synthesizer);
}