//===--- DerivedConformanceEquatableHashable.cpp - Derived Equatable & co. ===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2015 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See http://swift.org/LICENSE.txt for license information // See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements implicit derivation of the Equatable and Hashable // protocols. (Comparable is similar enough in spirit that it would make // sense to live here too when we implement its derivation.) // //===----------------------------------------------------------------------===// #include "TypeChecker.h" #include "swift/AST/ArchetypeBuilder.h" #include "swift/AST/Decl.h" #include "swift/AST/Stmt.h" #include "swift/AST/Expr.h" #include "swift/AST/Types.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/raw_ostream.h" #include "DerivedConformances.h" using namespace swift; using namespace DerivedConformance; /// Common preconditions for Equatable and Hashable. static bool canDeriveConformance(NominalTypeDecl *type) { // The type must be an enum. // TODO: Structs with Equatable/Hashable/Comparable members auto enumDecl = dyn_cast(type); if (!enumDecl) return false; // The enum must be simple. // TODO: Enums with Equatable/Hashable/Comparable payloads if (!enumDecl->isSimpleEnum()) return false; return true; } static Expr *getTrueExpr(ASTContext &C) { auto decl = C.getTrueDecl(); return new (C) DeclRefExpr(decl, SourceLoc(), /*implicit*/ true, /*direct access*/false, decl->getType()); } static Expr *getFalseExpr(ASTContext &C) { auto decl = C.getFalseDecl(); return new (C) DeclRefExpr(decl, SourceLoc(), /*implicit*/ true, /*direct access*/false, decl->getType()); } static void deriveBodyEquatable_enum_eq(AbstractFunctionDecl *eqDecl) { auto args = cast(eqDecl->getBodyParamPatterns().back()); auto aPattern = args->getFields()[0].getPattern(); auto aParamPattern = cast(aPattern->getSemanticsProvidingPattern()); auto aParam = aParamPattern->getDecl(); auto bPattern = args->getFields()[1].getPattern(); auto bParamPattern = cast(bPattern->getSemanticsProvidingPattern()); auto bParam = bParamPattern->getDecl(); auto enumDecl = cast(aParam->getType()->getAnyNominal()); ASTContext &C = enumDecl->getASTContext(); auto enumTy = enumDecl->getDeclaredTypeInContext(); SmallVector cases; SmallVector caseLabelItems; for (auto elt : enumDecl->getAllElements()) { assert(!elt->hasArgumentType() && "enums with payloads not supported yet"); auto aPat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumTy), SourceLoc(), SourceLoc(), Identifier(), elt, nullptr); aPat->setImplicit(); auto bPat = aPat->clone(C, Pattern::Implicit); TuplePatternElt tupleElts[] = { TuplePatternElt(aPat), TuplePatternElt(bPat) }; auto tuplePat = TuplePattern::create(C, SourceLoc(), tupleElts, SourceLoc()); tuplePat->setImplicit(); caseLabelItems.push_back( CaseLabelItem(/*IsDefault=*/false, tuplePat, SourceLoc(), nullptr)); } { Expr *trueExpr = getTrueExpr(C); auto returnStmt = new (C) ReturnStmt(SourceLoc(), trueExpr); BraceStmt* body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc()); cases.push_back( CaseStmt::create(C, SourceLoc(), caseLabelItems, /*HasBoundDecls=*/false, SourceLoc(), body)); } { auto any = new (C) AnyPattern(SourceLoc()); any->setImplicit(); auto labelItem = CaseLabelItem(/*IsDefault=*/true, any, SourceLoc(), nullptr); Expr *falseExpr = getFalseExpr(C); auto returnStmt = new (C) ReturnStmt(SourceLoc(), falseExpr); BraceStmt* body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc()); cases.push_back( CaseStmt::create(C, SourceLoc(), labelItem, /*HasBoundDecls=*/false, SourceLoc(), body)); } auto aRef = new (C) DeclRefExpr(aParam, SourceLoc(), /*implicit*/ true); auto bRef = new (C) DeclRefExpr(bParam, SourceLoc(), /*implicit*/ true); Expr *ab[] = {aRef, bRef}; TupleExpr *abTuple = TupleExpr::createImplicit(C, ab, { }); auto switchStmt = SwitchStmt::create(LabeledStmtInfo(), SourceLoc(), abTuple, SourceLoc(), cases, SourceLoc(), C); BraceStmt *body = BraceStmt::create(C, SourceLoc(), ASTNode(switchStmt), SourceLoc()); eqDecl->setBody(body); } /// Derive an '==' operator implementation for an enum. static ValueDecl * deriveEquatable_enum_eq(TypeChecker &tc, EnumDecl *enumDecl) { // enum SomeEnum { // case A, B, C // } // @derived // func ==(a: SomeEnum, b: SomeEnum) -> Bool { // switch (a, b) { // case (.A, .A): // case (.B, .B): // case (.C, .C): // return true // case _: // return false // } // } ASTContext &C = tc.Context; auto enumTy = enumDecl->getDeclaredTypeInContext(); auto getParamPattern = [&](StringRef s) -> std::pair { VarDecl *aDecl = new (C) ParamDecl(/*isLet*/ true, SourceLoc(), C.getIdentifier(s), SourceLoc(), C.getIdentifier(s), enumTy, enumDecl); aDecl->setImplicit(); Pattern *aParam = new (C) NamedPattern(aDecl, /*implicit*/ true); aParam->setType(enumTy); aParam = new (C) TypedPattern(aParam, TypeLoc::withoutLoc(enumTy)); aParam->setType(enumTy); aParam->setImplicit(); return {aDecl, aParam}; }; auto aParam = getParamPattern("a"); auto bParam = getParamPattern("b"); TupleTypeElt typeElts[] = { TupleTypeElt(enumTy), TupleTypeElt(enumTy) }; auto paramsTy = TupleType::get(typeElts, C); TuplePatternElt paramElts[] = { TuplePatternElt(aParam.second), TuplePatternElt(bParam.second), }; auto params = TuplePattern::create(C, SourceLoc(), paramElts, SourceLoc()); params->setImplicit(); params->setType(paramsTy); auto genericParams = enumDecl->getGenericParamsOfContext(); auto boolTy = C.getBoolDecl()->getDeclaredType(); DeclName name(C, C.Id_EqualsOperator, { Identifier(), Identifier() }); auto eqDecl = FuncDecl::create(C, SourceLoc(), StaticSpellingKind::None, SourceLoc(), name, SourceLoc(), genericParams, Type(), params, TypeLoc::withoutLoc(boolTy), &enumDecl->getModuleContext()->getDerivedFileUnit()); eqDecl->setImplicit(); eqDecl->getMutableAttrs().setAttr(AttrKind::AK_infix, SourceLoc()); auto op = C.getStdlibModule()->lookupInfixOperator(C.Id_EqualsOperator); if (!op) { tc.diagnose(enumDecl->getLoc(), diag::broken_equatable_eq_operator); return nullptr; } eqDecl->setOperatorDecl(op); eqDecl->setDerivedForTypeDecl(enumDecl); eqDecl->setBodySynthesizer(&deriveBodyEquatable_enum_eq); // Compute the type and interface type. Type fnTy, interfaceTy; if (genericParams) { fnTy = PolymorphicFunctionType::get(paramsTy, boolTy, genericParams); auto enumIfaceTy = enumDecl->getDeclaredInterfaceType(); TupleTypeElt ifaceParamElts[] = { enumIfaceTy, enumIfaceTy, }; auto ifaceParamsTy = TupleType::get(ifaceParamElts, C); interfaceTy = GenericFunctionType::get( enumDecl->getGenericSignatureOfContext(), ifaceParamsTy, boolTy, AnyFunctionType::ExtInfo()); } else { fnTy = interfaceTy = FunctionType::get(paramsTy, boolTy); } eqDecl->setType(fnTy); eqDecl->setInterfaceType(interfaceTy); if (enumDecl->hasClangNode()) tc.implicitlyDefinedFunctions.push_back(eqDecl); // Since it's an operator we insert the decl after the type at global scope. return insertOperatorDecl(enumDecl, eqDecl); } ValueDecl *DerivedConformance::deriveEquatable(TypeChecker &tc, NominalTypeDecl *type, ValueDecl *requirement) { // Check that we can actually derive Equatable for this type. if (!canDeriveConformance(type)) return nullptr; // Build the necessary decl. if (requirement->getName().str() == "==") { if (auto theEnum = dyn_cast(type)) return deriveEquatable_enum_eq(tc, theEnum); else llvm_unreachable("todo"); } tc.diagnose(requirement->getLoc(), diag::broken_equatable_requirement); return nullptr; } static void deriveBodyHashable_enum_hashValue(AbstractFunctionDecl *hashValueDecl) { auto enumDecl = cast(hashValueDecl->getDeclContext()); ASTContext &C = enumDecl->getASTContext(); auto enumType = enumDecl->getDeclaredTypeInContext(); Type intType = C.getIntDecl()->getDeclaredType(); auto indexVar = new (C) ParamDecl(/*let*/false, SourceLoc(), C.getIdentifier("index"), SourceLoc(), C.getIdentifier("index"), intType, hashValueDecl); indexVar->setImplicit(); Pattern *indexPat = new (C) NamedPattern(indexVar, /*implicit*/ true); indexPat->setType(intType); indexPat = new (C) TypedPattern(indexPat, TypeLoc::withoutLoc(intType)); indexPat->setType(intType); auto indexBind = new (C) PatternBindingDecl(SourceLoc(), StaticSpellingKind::None, SourceLoc(), indexPat, nullptr, /*conditional*/ false, enumDecl); unsigned index = 0; SmallVector cases; for (auto elt : enumDecl->getAllElements()) { auto pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType), SourceLoc(), SourceLoc(), Identifier(), elt, nullptr); pat->setImplicit(); auto labelItem = CaseLabelItem(/*IsDefault=*/false, pat, SourceLoc(), nullptr); llvm::SmallString<8> indexVal; APInt(32, index++).toString(indexVal, 10, /*signed*/ false); auto indexStr = C.AllocateCopy(indexVal); auto indexExpr = new (C) IntegerLiteralExpr( StringRef(indexStr.data(), indexStr.size()), SourceLoc(), /*implicit*/ true); auto indexRef = new (C) DeclRefExpr(indexVar, SourceLoc(), /*implicit*/true); auto assignExpr = new (C) AssignExpr(indexRef, SourceLoc(), indexExpr, /*implicit*/ true); auto body = BraceStmt::create(C, SourceLoc(), ASTNode(assignExpr), SourceLoc()); cases.push_back( CaseStmt::create(C, SourceLoc(), labelItem, /*HasBoundDecls=*/false, SourceLoc(), body)); } Pattern *curriedArgs = hashValueDecl->getBodyParamPatterns().front(); auto selfPattern = cast(curriedArgs->getSemanticsProvidingPattern()); auto selfDecl = selfPattern->getDecl(); auto selfRef = new (C) DeclRefExpr(selfDecl, SourceLoc(), /*implicit*/true); auto switchStmt = SwitchStmt::create(LabeledStmtInfo(), SourceLoc(), selfRef, SourceLoc(), cases, SourceLoc(), C); auto indexRef = new (C) DeclRefExpr(indexVar, SourceLoc(), /*implicit*/ true); auto memberRef = new (C) UnresolvedDotExpr(indexRef, SourceLoc(), C.getIdentifier("hashValue"), SourceLoc(), /*implicit*/true); auto returnStmt = new (C) ReturnStmt(SourceLoc(), memberRef); ASTNode bodyStmts[] = { indexBind, switchStmt, returnStmt, }; auto body = BraceStmt::create(C, SourceLoc(), bodyStmts, SourceLoc()); hashValueDecl->setBody(body); } /// Derive a 'hashValue' implementation for an enum. static ValueDecl * deriveHashable_enum_hashValue(TypeChecker &tc, EnumDecl *enumDecl) { // enum SomeEnum { // case A, B, C // @derived var hashValue: Int { // var index: Int // switch self { // case A: // index = 0 // case B: // index = 1 // case C: // index = 2 // } // return index.hashValue // } // } ASTContext &C = tc.Context; Type enumType = enumDecl->getDeclaredTypeInContext(); Type intType = C.getIntDecl()->getDeclaredType(); // We can't form a Hashable conformance if Int isn't Hashable or // IntegerLiteralConvertible. if (!tc.conformsToProtocol(intType, C.getProtocol(KnownProtocolKind::Hashable), enumDecl->getModuleContext())) { tc.diagnose(enumDecl->getLoc(), diag::broken_int_hashable_conformance); return nullptr; } if (!tc.conformsToProtocol(intType, C.getProtocol(KnownProtocolKind::IntegerLiteralConvertible), enumDecl->getModuleContext())) { tc.diagnose(enumDecl->getLoc(), diag::broken_int_integer_literal_convertible_conformance); return nullptr; } VarDecl *selfDecl = new (C) ParamDecl(/*IsLet*/true, SourceLoc(), Identifier(), SourceLoc(), C.Id_self, enumType, enumDecl); selfDecl->setImplicit(); Pattern *selfParam = new (C) NamedPattern(selfDecl, /*implicit*/ true); selfParam->setType(enumType); selfParam = new (C) TypedPattern(selfParam, TypeLoc::withoutLoc(enumType)); selfParam->setType(enumType); Pattern *methodParam = TuplePattern::create(C, SourceLoc(),{},SourceLoc()); methodParam->setType(TupleType::getEmpty(tc.Context)); Pattern *params[] = {selfParam, methodParam}; FuncDecl *getterDecl = FuncDecl::create(C, SourceLoc(), StaticSpellingKind::None, SourceLoc(), Identifier(), SourceLoc(), nullptr, Type(), params, TypeLoc::withoutLoc(intType), enumDecl); getterDecl->setImplicit(); getterDecl->setBodySynthesizer(deriveBodyHashable_enum_hashValue); // Compute the type of hashValue(). GenericParamList *genericParams = nullptr; Type methodType = FunctionType::get(TupleType::getEmpty(tc.Context), intType); Type selfType = getterDecl->computeSelfType(&genericParams); Type type; if (genericParams) type = PolymorphicFunctionType::get(selfType, methodType, genericParams); else type = FunctionType::get(selfType, methodType); getterDecl->setType(type); getterDecl->setBodyResultType(intType); // Compute the interface type of hashValue(). Type interfaceType; Type selfIfaceType = getterDecl->computeInterfaceSelfType(false); if (auto sig = enumDecl->getGenericSignatureOfContext()) interfaceType = GenericFunctionType::get(sig, selfIfaceType, methodType, AnyFunctionType::ExtInfo()); else interfaceType = type; getterDecl->setInterfaceType(interfaceType); if (enumDecl->hasClangNode()) tc.implicitlyDefinedFunctions.push_back(getterDecl); // Create the property. VarDecl *hashValueDecl = new (C) VarDecl(/*static*/ false, /*let*/ false, SourceLoc(), C.Id_hashValue, intType, enumDecl); hashValueDecl->setImplicit(); hashValueDecl->makeComputed(SourceLoc(), getterDecl, nullptr, SourceLoc()); Pattern *hashValuePat = new (C) NamedPattern(hashValueDecl, /*implicit*/true); hashValuePat->setType(intType); hashValuePat = new (C) TypedPattern(hashValuePat, TypeLoc::withoutLoc(intType), /*implicit*/ true); hashValuePat->setType(intType); auto patDecl = new (C) PatternBindingDecl(SourceLoc(), StaticSpellingKind::None, SourceLoc(), hashValuePat, nullptr, /*conditional*/true, enumDecl); patDecl->setImplicit(); enumDecl->addMember(getterDecl); enumDecl->addMember(hashValueDecl); enumDecl->addMember(patDecl); return hashValueDecl; } ValueDecl *DerivedConformance::deriveHashable(TypeChecker &tc, NominalTypeDecl *type, ValueDecl *requirement) { // Check that we can actually derive Hashable for this type. if (!canDeriveConformance(type)) return nullptr; // Build the necessary decl. if (requirement->getName().str() == "hashValue") { if (auto theEnum = dyn_cast(type)) return deriveHashable_enum_hashValue(tc, theEnum); else llvm_unreachable("todo"); } tc.diagnose(requirement->getLoc(), diag::broken_hashable_requirement); return nullptr; }