//===--- DerivedConformanceEquatableHashable.cpp - Derived Equatable & co. ===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2015 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See http://swift.org/LICENSE.txt for license information // See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements implicit derivation of the Equatable and Hashable // protocols. (Comparable is similar enough in spirit that it would make // sense to live here too when we implement its derivation.) // //===----------------------------------------------------------------------===// #include "TypeChecker.h" #include "llvm/ADT/APInt.h" #include "llvm/Support/raw_ostream.h" #include "swift/AST/ArchetypeBuilder.h" #include "swift/AST/Decl.h" #include "swift/AST/Stmt.h" #include "swift/AST/Expr.h" #include "swift/AST/Types.h" #include "DerivedConformances.h" using namespace swift; using namespace DerivedConformance; /// Common preconditions for Equatable and Hashable. static bool canDeriveConformance(NominalTypeDecl *type) { // The type must be an enum. // TODO: Structs with Equatable/Hashable/Comparable members auto enumDecl = dyn_cast(type); if (!enumDecl) return false; // The enum must be simple. // TODO: Enums with Equatable/Hashable/Comparable payloads if (!enumDecl->isSimpleEnum()) return false; return true; } static Expr *getTrueExpr(ASTContext &C) { auto decl = C.getTrueDecl(); return new (C) DeclRefExpr(decl, SourceLoc(), /*implicit*/ true, /*direct access*/false, decl->getType()); } static Expr *getFalseExpr(ASTContext &C) { auto decl = C.getFalseDecl(); return new (C) DeclRefExpr(decl, SourceLoc(), /*implicit*/ true, /*direct access*/false, decl->getType()); } /// Derive an '==' operator implementation for an enum. static ValueDecl * deriveEquatable_enum_eq(TypeChecker &tc, EnumDecl *enumDecl) { // enum SomeEnum { // case A, B, C // } // @derived // func ==(a: SomeEnum, b: SomeEnum) -> Bool { // switch (a, b) { // case (.A, .A): // case (.B, .B): // case (.C, .C): // return true // case _: // return false // } // } ASTContext &C = tc.Context; auto enumTy = enumDecl->getDeclaredTypeInContext(); auto getParamPattern = [&](StringRef s) -> std::pair { VarDecl *aDecl = new (C) VarDecl(/*static*/ false, /*isLet*/ true, SourceLoc(), C.getIdentifier(s), enumTy, enumDecl); aDecl->setImplicit(); Pattern *aParam = new (C) NamedPattern(aDecl, /*implicit*/ true); aParam->setType(enumTy); aParam = new (C) TypedPattern(aParam, TypeLoc::withoutLoc(enumTy)); aParam->setType(enumTy); aParam->setImplicit(); return {aDecl, aParam}; }; auto aParam = getParamPattern("a"); auto bParam = getParamPattern("b"); TupleTypeElt typeElts[] = { TupleTypeElt(enumTy), TupleTypeElt(enumTy) }; auto paramsTy = TupleType::get(typeElts, C); TuplePatternElt paramElts[] = { TuplePatternElt(aParam.second), TuplePatternElt(bParam.second), }; auto params = TuplePattern::create(C, SourceLoc(), paramElts, SourceLoc()); params->setImplicit(); params->setType(paramsTy); Pattern *argParams = params->clone(C, Pattern::Implicit); auto genericParams = enumDecl->getGenericParamsOfContext(); auto boolTy = C.getBoolDecl()->getDeclaredType(); auto id_eq = C.getIdentifier("=="); auto eqDecl = FuncDecl::create(C, SourceLoc(), StaticSpellingKind::None, SourceLoc(), id_eq, SourceLoc(), genericParams, Type(), argParams, params, TypeLoc::withoutLoc(boolTy), &enumDecl->getModuleContext()->getDerivedFileUnit()); eqDecl->setImplicit(); eqDecl->getMutableAttrs().setAttr(AttrKind::AK_infix, SourceLoc()); auto op = C.getStdlibModule()->lookupInfixOperator(id_eq); if (!op) { tc.diagnose(enumDecl->getLoc(), diag::broken_equatable_eq_operator); return nullptr; } eqDecl->setOperatorDecl(op); eqDecl->setDerivedForTypeDecl(enumDecl); SmallVector cases; SmallVector caseLabelItems; for (auto elt : enumDecl->getAllElements()) { assert(!elt->hasArgumentType() && "enums with payloads not supported yet"); auto aPat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumTy), SourceLoc(), SourceLoc(), Identifier(), elt, nullptr); aPat->setImplicit(); auto bPat = aPat->clone(C, Pattern::Implicit); TuplePatternElt tupleElts[] = { TuplePatternElt(aPat), TuplePatternElt(bPat) }; auto tuplePat = TuplePattern::create(C, SourceLoc(), tupleElts, SourceLoc()); tuplePat->setImplicit(); caseLabelItems.push_back( CaseLabelItem(/*IsDefault=*/false, tuplePat, SourceLoc(), nullptr)); } { Expr *trueExpr = getTrueExpr(C); auto returnStmt = new (C) ReturnStmt(SourceLoc(), trueExpr); BraceStmt* body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc()); cases.push_back( CaseStmt::create(C, SourceLoc(), caseLabelItems, /*HasBoundDecls=*/false, SourceLoc(), body)); } { auto any = new (C) AnyPattern(SourceLoc()); any->setImplicit(); auto labelItem = CaseLabelItem(/*IsDefault=*/true, any, SourceLoc(), nullptr); Expr *falseExpr = getFalseExpr(C); auto returnStmt = new (C) ReturnStmt(SourceLoc(), falseExpr); BraceStmt* body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc()); cases.push_back( CaseStmt::create(C, SourceLoc(), labelItem, /*HasBoundDecls=*/false, SourceLoc(), body)); } auto aRef = new (C) DeclRefExpr(aParam.first, SourceLoc(), /*implicit*/ true); auto bRef = new (C) DeclRefExpr(bParam.first, SourceLoc(), /*implicit*/ true); Expr *ab[] = {aRef, bRef}; TupleExpr *abTuple = new (C) TupleExpr(SourceLoc(), C.AllocateCopy(ab), nullptr, SourceLoc(), /*trailingClosure*/ false, /*implicit*/ true); auto switchStmt = SwitchStmt::create(SourceLoc(), abTuple, SourceLoc(), cases, SourceLoc(), C); BraceStmt *body = BraceStmt::create(C, SourceLoc(), ASTNode(switchStmt), SourceLoc()); eqDecl->setBody(body); // Compute the type and interface type. Type fnTy, interfaceTy; if (genericParams) { fnTy = PolymorphicFunctionType::get(paramsTy, boolTy, genericParams); auto enumIfaceTy = enumDecl->getDeclaredInterfaceType(); TupleTypeElt ifaceParamElts[] = { enumIfaceTy, enumIfaceTy, }; auto ifaceParamsTy = TupleType::get(ifaceParamElts, C); interfaceTy = GenericFunctionType::get( enumDecl->getGenericSignatureOfContext(), ifaceParamsTy, boolTy, AnyFunctionType::ExtInfo()); } else { fnTy = interfaceTy = FunctionType::get(paramsTy, boolTy); } eqDecl->setType(fnTy); eqDecl->setInterfaceType(interfaceTy); tc.implicitlyDefinedFunctions.push_back(eqDecl); // Since it's an operator we insert the decl after the type at global scope. return insertOperatorDecl(enumDecl, eqDecl); } ValueDecl *DerivedConformance::deriveEquatable(TypeChecker &tc, NominalTypeDecl *type, ValueDecl *requirement) { // Check that we can actually derive Equatable for this type. if (!canDeriveConformance(type)) return nullptr; // Build the necessary decl. if (requirement->getName().str() == "==") { if (auto theEnum = dyn_cast(type)) return deriveEquatable_enum_eq(tc, theEnum); else llvm_unreachable("todo"); } tc.diagnose(requirement->getLoc(), diag::broken_equatable_requirement); return nullptr; } /// Derive a 'hashValue' implementation for an enum. static ValueDecl * deriveHashable_enum_hashValue(TypeChecker &tc, EnumDecl *enumDecl) { // enum SomeEnum { // case A, B, C // @derived func hashValue() -> Int { // var index: Int // switch self { // case A: // index = 0 // case B: // index = 1 // case C: // index = 2 // } // return index.hashValue() // } // } ASTContext &C = tc.Context; Type enumType = enumDecl->getDeclaredTypeInContext(); Type intType = C.getIntDecl()->getDeclaredType(); // We can't form a Hashable conformance if Int isn't Hashable or // IntegerLiteralConvertible. if (!tc.conformsToProtocol(intType, C.getProtocol(KnownProtocolKind::Hashable), enumDecl->getModuleContext())) { tc.diagnose(enumDecl->getLoc(), diag::broken_int_hashable_conformance); return nullptr; } if (!tc.conformsToProtocol(intType, C.getProtocol(KnownProtocolKind::IntegerLiteralConvertible), enumDecl->getModuleContext())) { tc.diagnose(enumDecl->getLoc(), diag::broken_int_integer_literal_convertible_conformance); return nullptr; } VarDecl *selfDecl = new (C) VarDecl(/*static*/ false, /*IsLet*/true, SourceLoc(), C.Id_self, enumType, enumDecl); selfDecl->setImplicit(); Pattern *selfParam = new (C) NamedPattern(selfDecl, /*implicit*/ true); selfParam->setType(enumType); selfParam = new (C) TypedPattern(selfParam, TypeLoc::withoutLoc(enumType)); selfParam->setType(enumType); Pattern *methodParam = TuplePattern::create(C, SourceLoc(),{},SourceLoc()); methodParam->setType(TupleType::getEmpty(tc.Context)); Pattern *params[] = {selfParam, methodParam}; Identifier id_hashValue = C.getIdentifier("hashValue"); FuncDecl *hashValueDecl = FuncDecl::create(C, SourceLoc(), StaticSpellingKind::None, SourceLoc(), id_hashValue, SourceLoc(), nullptr, Type(), params, params, TypeLoc::withoutLoc(intType), enumDecl); hashValueDecl->setImplicit(); auto indexVar = new (C) VarDecl(/*static*/false, /*let*/false, SourceLoc(), C.getIdentifier("index"), intType, hashValueDecl); indexVar->setImplicit(); Pattern *indexPat = new (C) NamedPattern(indexVar, /*implicit*/ true); indexPat = new (C) TypedPattern(indexPat, TypeLoc::withoutLoc(intType)); auto indexBind = new (C) PatternBindingDecl(SourceLoc(), StaticSpellingKind::None, SourceLoc(), indexPat, nullptr, /*storage*/ true, /*conditional*/ false, enumDecl); unsigned index = 0; SmallVector cases; for (auto elt : enumDecl->getAllElements()) { auto pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType), SourceLoc(), SourceLoc(), Identifier(), elt, nullptr); pat->setImplicit(); auto labelItem = CaseLabelItem(/*IsDefault=*/false, pat, SourceLoc(), nullptr); llvm::SmallString<8> indexVal; APInt(32, index++).toString(indexVal, 10, /*signed*/ false); auto indexStr = C.AllocateCopy(indexVal); auto indexExpr = new (C) IntegerLiteralExpr( StringRef(indexStr.data(), indexStr.size()), SourceLoc(), /*implicit*/ true); auto indexRef = new (C) DeclRefExpr(indexVar, SourceLoc(), /*implicit*/true); auto assignExpr = new (C) AssignExpr(indexRef, SourceLoc(), indexExpr, /*implicit*/ true); auto body = BraceStmt::create(C, SourceLoc(), ASTNode(assignExpr), SourceLoc()); cases.push_back( CaseStmt::create(C, SourceLoc(), labelItem, /*HasBoundDecls=*/false, SourceLoc(), body)); } auto selfRef = new (C) DeclRefExpr(selfDecl, SourceLoc(), /*implicit*/true); auto switchStmt = SwitchStmt::create(SourceLoc(), selfRef, SourceLoc(), cases, SourceLoc(), C); auto indexRef = new (C) DeclRefExpr(indexVar, SourceLoc(), /*implicit*/ true); auto memberRef = new (C) UnresolvedDotExpr(indexRef, SourceLoc(), id_hashValue, SourceLoc(), /*implicit*/ true); auto args = new (C) TupleExpr(SourceLoc(), SourceLoc(), /*implicit*/ true); auto call = new (C) CallExpr(memberRef, args, /*implicit*/true); auto returnStmt = new (C) ReturnStmt(SourceLoc(), call); ASTNode bodyStmts[] = { indexBind, switchStmt, returnStmt, }; auto body = BraceStmt::create(C, SourceLoc(), bodyStmts, SourceLoc()); hashValueDecl->setBody(body); // Compute the type of hashValue(). GenericParamList *genericParams = nullptr; Type methodType = FunctionType::get(TupleType::getEmpty(tc.Context), intType); Type selfType = hashValueDecl->computeSelfType(&genericParams); Type type; if (genericParams) type = PolymorphicFunctionType::get(selfType, methodType, genericParams); else type = FunctionType::get(selfType, methodType); hashValueDecl->setType(type); hashValueDecl->setBodyResultType(intType); // Compute the interface type of hashValue(). Type interfaceType; Type selfIfaceType = hashValueDecl->computeInterfaceSelfType(false); if (auto sig = enumDecl->getGenericSignatureOfContext()) interfaceType = GenericFunctionType::get(sig, selfIfaceType, methodType, AnyFunctionType::ExtInfo()); else interfaceType = type; hashValueDecl->setInterfaceType(interfaceType); tc.implicitlyDefinedFunctions.push_back(hashValueDecl); return insertMemberDecl(enumDecl, hashValueDecl); } ValueDecl *DerivedConformance::deriveHashable(TypeChecker &tc, NominalTypeDecl *type, ValueDecl *requirement) { // Check that we can actually derive Hashable for this type. if (!canDeriveConformance(type)) return nullptr; // Build the necessary decl. if (requirement->getName().str() == "hashValue") { if (auto theEnum = dyn_cast(type)) return deriveHashable_enum_hashValue(tc, theEnum); else llvm_unreachable("todo"); } tc.diagnose(requirement->getLoc(), diag::broken_hashable_requirement); return nullptr; }