//===--- ClangDerivedConformances.cpp -------------------------------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2022 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// #include "ClangDerivedConformances.h" #include "swift/AST/ParameterList.h" #include "swift/AST/PrettyStackTrace.h" #include "swift/AST/ProtocolConformance.h" #include "swift/ClangImporter/ClangImporterRequests.h" #include "clang/Sema/DelayedDiagnostic.h" #include "clang/Sema/Overload.h" using namespace swift; using namespace swift::importer; /// Alternative to `NominalTypeDecl::lookupDirect`. /// This function does not attempt to load extensions of the nominal decl. static TinyPtrVector lookupDirectWithoutExtensions(NominalTypeDecl *decl, Identifier id) { ASTContext &ctx = decl->getASTContext(); auto *importer = static_cast(ctx.getClangModuleLoader()); TinyPtrVector result; if (id.isOperator()) { auto underlyingId = ctx.getIdentifier(getPrivateOperatorName(std::string(id))); TinyPtrVector underlyingFuncs = evaluateOrDefault( ctx.evaluator, ClangRecordMemberLookup({decl, underlyingId}), {}); for (auto it : underlyingFuncs) { if (auto synthesizedFunc = importer->getCXXSynthesizedOperatorFunc(cast(it))) result.push_back(synthesizedFunc); } } else { // See if there is a Clang decl with the given name. result = evaluateOrDefault(ctx.evaluator, ClangRecordMemberLookup({decl, id}), {}); } // Check if there are any synthesized Swift members that match the name. for (auto member : decl->getCurrentMembersWithoutLoading()) { if (auto namedMember = dyn_cast(member)) { if (namedMember->hasName() && !namedMember->getName().isSpecial() && namedMember->getName().getBaseIdentifier().is(id.str()) && // Make sure we don't add duplicate entries, as that would wrongly // imply that lookup is ambiguous. !llvm::is_contained(result, namedMember)) { result.push_back(namedMember); } } } return result; } template static Decl *lookupDirectSingleWithoutExtensions(NominalTypeDecl *decl, Identifier id) { auto results = lookupDirectWithoutExtensions(decl, id); if (results.size() != 1) return nullptr; return dyn_cast(results.front()); } /// Similar to ModuleDecl::conformsToProtocol, but doesn't introduce a /// dependency on Sema. static bool isConcreteAndValid(ProtocolConformanceRef conformanceRef, ModuleDecl *module) { if (conformanceRef.isInvalid()) return false; if (!conformanceRef.isConcrete()) return false; auto conformance = conformanceRef.getConcrete(); auto subMap = conformance->getSubstitutionMap(); return llvm::all_of(subMap.getConformances(), [&](ProtocolConformanceRef each) -> bool { return isConcreteAndValid(each, module); }); } static bool isStdDecl(const clang::CXXRecordDecl *clangDecl, llvm::ArrayRef names) { if (!clangDecl->isInStdNamespace()) return false; if (!clangDecl->getIdentifier()) return false; StringRef name = clangDecl->getName(); return llvm::is_contained(names, name); } static clang::TypeDecl * getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) { clang::IdentifierInfo *iteratorCategoryDeclName = &clangDecl->getASTContext().Idents.get("iterator_category"); auto iteratorCategories = clangDecl->lookup(iteratorCategoryDeclName); if (!iteratorCategories.isSingleResult()) return nullptr; auto iteratorCategory = iteratorCategories.front(); return dyn_cast_or_null(iteratorCategory); } static ValueDecl *lookupOperator(NominalTypeDecl *decl, Identifier id, function_ref isValid) { // First look for operator declared as a member. auto memberResults = lookupDirectWithoutExtensions(decl, id); for (const auto &member : memberResults) { if (isValid(member)) return member; } // If no member operator was found, look for out-of-class definitions in the // same module. auto module = decl->getModuleContext(); SmallVector nonMemberResults; module->lookupValue(id, NLKind::UnqualifiedLookup, nonMemberResults); for (const auto &nonMember : nonMemberResults) { if (isValid(nonMember)) return nonMember; } return nullptr; } static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) { auto isValid = [&](ValueDecl *equalEqualOp) -> bool { auto equalEqual = dyn_cast(equalEqualOp); if (!equalEqual || !equalEqual->hasParameterList()) return false; auto params = equalEqual->getParameters(); if (params->size() != 2) return false; auto lhs = params->get(0); auto rhs = params->get(1); if (lhs->isInOut() || rhs->isInOut()) return false; auto lhsTy = lhs->getType(); auto rhsTy = rhs->getType(); if (!lhsTy || !rhsTy) return false; auto lhsNominal = lhsTy->getAnyNominal(); auto rhsNominal = rhsTy->getAnyNominal(); if (lhsNominal != rhsNominal || lhsNominal != decl) return false; return true; }; return lookupOperator(decl, decl->getASTContext().Id_EqualsOperator, isValid); } static ValueDecl *getMinusOperator(NominalTypeDecl *decl) { auto binaryIntegerProto = decl->getASTContext().getProtocol(KnownProtocolKind::BinaryInteger); auto module = decl->getModuleContext(); auto isValid = [&](ValueDecl *minusOp) -> bool { auto minus = dyn_cast(minusOp); if (!minus || !minus->hasParameterList()) return false; auto params = minus->getParameters(); if (params->size() != 2) return false; auto lhs = params->get(0); auto rhs = params->get(1); if (lhs->isInOut() || rhs->isInOut()) return false; auto lhsTy = lhs->getType(); auto rhsTy = rhs->getType(); if (!lhsTy || !rhsTy) return false; auto lhsNominal = lhsTy->getAnyNominal(); auto rhsNominal = rhsTy->getAnyNominal(); if (lhsNominal != rhsNominal || lhsNominal != decl) return false; auto returnTy = minus->getResultInterfaceType(); auto conformanceRef = module->lookupConformance(returnTy, binaryIntegerProto); if (!isConcreteAndValid(conformanceRef, module)) return false; return true; }; return lookupOperator(decl, decl->getASTContext().getIdentifier("-"), isValid); } static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) { auto isValid = [&](ValueDecl *plusEqualOp) -> bool { auto plusEqual = dyn_cast(plusEqualOp); if (!plusEqual || !plusEqual->hasParameterList()) return false; auto params = plusEqual->getParameters(); if (params->size() != 2) return false; auto lhs = params->get(0); auto rhs = params->get(1); if (rhs->isInOut()) return false; auto lhsTy = lhs->getType(); auto rhsTy = rhs->getType(); if (!lhsTy || !rhsTy) return false; if (rhsTy->getCanonicalType() != distanceTy->getCanonicalType()) return false; auto lhsNominal = lhsTy->getAnyNominal(); if (lhsNominal != decl) return false; auto returnTy = plusEqual->getResultInterfaceType(); if (!returnTy->isVoid()) return false; return true; }; return lookupOperator(decl, decl->getASTContext().getIdentifier("+="), isValid); } static void instantiateTemplatedOperator( ClangImporter::Implementation &impl, const clang::ClassTemplateSpecializationDecl *classDecl, clang::BinaryOperatorKind operatorKind) { clang::ASTContext &clangCtx = impl.getClangASTContext(); clang::Sema &clangSema = impl.getClangSema(); clang::UnresolvedSet<1> ops; auto qualType = clang::QualType(classDecl->getTypeForDecl(), 0); auto arg = new (clangCtx) clang::CXXThisExpr(clang::SourceLocation(), qualType, false); arg->setType(clang::QualType(classDecl->getTypeForDecl(), 0)); clang::OverloadedOperatorKind opKind = clang::BinaryOperator::getOverloadedOperator(operatorKind); clang::OverloadCandidateSet candidateSet( classDecl->getLocation(), clang::OverloadCandidateSet::CSK_Operator, clang::OverloadCandidateSet::OperatorRewriteInfo(opKind, clang::SourceLocation(), false)); clangSema.LookupOverloadedBinOp(candidateSet, opKind, ops, {arg, arg}, true); clang::OverloadCandidateSet::iterator best; switch (candidateSet.BestViableFunction(clangSema, clang::SourceLocation(), best)) { case clang::OR_Success: { if (auto clangCallee = best->Function) { auto lookupTable = impl.findLookupTable(classDecl); addEntryToLookupTable(*lookupTable, clangCallee, impl.getNameImporter()); } break; } case clang::OR_No_Viable_Function: case clang::OR_Ambiguous: case clang::OR_Deleted: break; } } bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) { return getIteratorCategoryDecl(clangDecl); } void swift::conformToCxxIteratorIfNeeded( ClangImporter::Implementation &impl, NominalTypeDecl *decl, const clang::CXXRecordDecl *clangDecl) { PrettyStackTraceDecl trace("conforming to UnsafeCxxInputIterator", decl); assert(decl); assert(clangDecl); ASTContext &ctx = decl->getASTContext(); if (!ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator)) return; // We consider a type to be an input iterator if it defines an // `iterator_category` that inherits from `std::input_iterator_tag`, e.g. // `using iterator_category = std::input_iterator_tag`. auto iteratorCategory = getIteratorCategoryDecl(clangDecl); if (!iteratorCategory) return; // If `iterator_category` is a typedef or a using-decl, retrieve the // underlying struct decl. clang::CXXRecordDecl *underlyingCategoryDecl = nullptr; if (auto typedefDecl = dyn_cast(iteratorCategory)) { auto type = typedefDecl->getUnderlyingType(); underlyingCategoryDecl = type->getAsCXXRecordDecl(); } else { underlyingCategoryDecl = dyn_cast(iteratorCategory); } if (underlyingCategoryDecl) { underlyingCategoryDecl = underlyingCategoryDecl->getDefinition(); } if (!underlyingCategoryDecl) return; auto isIteratorCategoryDecl = [&](const clang::CXXRecordDecl *base, StringRef tag) { return base->isInStdNamespace() && base->getIdentifier() && base->getName() == tag; }; auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) { return isIteratorCategoryDecl(base, "input_iterator_tag"); }; auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) { return isIteratorCategoryDecl(base, "random_access_iterator_tag"); }; // Traverse all transitive bases of `underlyingDecl` to check if // it inherits from `std::input_iterator_tag`. bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl); bool isRandomAccessIterator = isRandomAccessIteratorDecl(underlyingCategoryDecl); underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) { if (isInputIteratorDecl(base)) { isInputIterator = true; } if (isRandomAccessIteratorDecl(base)) { isRandomAccessIterator = true; isInputIterator = true; return false; } return true; }); if (!isInputIterator) return; // Check if present: `var pointee: Pointee { get }` auto pointeeId = ctx.getIdentifier("pointee"); auto pointee = lookupDirectSingleWithoutExtensions(decl, pointeeId); if (!pointee || pointee->isGetterMutating() || pointee->getType()->hasError()) return; // Check if present: `func successor() -> Self` auto successorId = ctx.getIdentifier("successor"); auto successor = lookupDirectSingleWithoutExtensions(decl, successorId); if (!successor || successor->isMutating()) return; auto successorTy = successor->getResultInterfaceType(); if (!successorTy || successorTy->getAnyNominal() != decl) return; // If this is a templated class, `operator==` might be templated as well. // Try to instantiate it. if (auto templateSpec = dyn_cast(clangDecl)) { instantiateTemplatedOperator(impl, templateSpec, clang::BinaryOperatorKind::BO_EQ); } // Check if present: `func ==` auto equalEqual = getEqualEqualOperator(decl); if (!equalEqual) return; impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Pointee"), pointee->getType()); impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::UnsafeCxxInputIterator}); if (!isRandomAccessIterator || !ctx.getProtocol(KnownProtocolKind::UnsafeCxxRandomAccessIterator)) return; // Try to conform to UnsafeCxxRandomAccessIterator if possible. if (auto templateSpec = dyn_cast(clangDecl)) { instantiateTemplatedOperator(impl, templateSpec, clang::BinaryOperatorKind::BO_Sub); } auto minus = dyn_cast_or_null(getMinusOperator(decl)); if (!minus) return; auto distanceTy = minus->getResultInterfaceType(); // distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator. auto plusEqual = dyn_cast_or_null(getPlusEqualOperator(decl, distanceTy)); if (!plusEqual) return; impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Distance"), distanceTy); impl.addSynthesizedProtocolAttrs( decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator}); } void swift::conformToCxxOptionalIfNeeded( ClangImporter::Implementation &impl, NominalTypeDecl *decl, const clang::CXXRecordDecl *clangDecl) { PrettyStackTraceDecl trace("conforming to CxxOptional", decl); assert(decl); assert(clangDecl); ASTContext &ctx = decl->getASTContext(); if (!isStdDecl(clangDecl, {"optional"})) return; ProtocolDecl *cxxOptionalProto = ctx.getProtocol(KnownProtocolKind::CxxOptional); // If the Cxx module is missing, or does not include one of the necessary // protocol, bail. if (!cxxOptionalProto) return; auto pointeeId = ctx.getIdentifier("pointee"); auto pointees = lookupDirectWithoutExtensions(decl, pointeeId); if (pointees.size() != 1) return; auto pointee = dyn_cast(pointees.front()); if (!pointee) return; auto pointeeTy = pointee->getInterfaceType(); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Wrapped"), pointeeTy); impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxOptional}); } void swift::conformToCxxSequenceIfNeeded( ClangImporter::Implementation &impl, NominalTypeDecl *decl, const clang::CXXRecordDecl *clangDecl) { PrettyStackTraceDecl trace("conforming to CxxSequence", decl); assert(decl); assert(clangDecl); ASTContext &ctx = decl->getASTContext(); ProtocolDecl *cxxIteratorProto = ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator); ProtocolDecl *cxxSequenceProto = ctx.getProtocol(KnownProtocolKind::CxxSequence); ProtocolDecl *cxxConvertibleProto = ctx.getProtocol(KnownProtocolKind::CxxConvertibleToCollection); // If the Cxx module is missing, or does not include one of the necessary // protocols, bail. if (!cxxIteratorProto || !cxxSequenceProto) return; // Check if present: `func __beginUnsafe() -> RawIterator` auto beginId = ctx.getIdentifier("__beginUnsafe"); auto begin = lookupDirectSingleWithoutExtensions(decl, beginId); if (!begin) return; auto rawIteratorTy = begin->getResultInterfaceType(); // Check if present: `func __endUnsafe() -> RawIterator` auto endId = ctx.getIdentifier("__endUnsafe"); auto end = lookupDirectSingleWithoutExtensions(decl, endId); if (!end) return; // Check if `begin()` and `end()` are non-mutating. if (begin->isMutating() || end->isMutating()) return; // Check if `__beginUnsafe` and `__endUnsafe` have the same return type. auto endTy = end->getResultInterfaceType(); if (!endTy || endTy->getCanonicalType() != rawIteratorTy->getCanonicalType()) return; // Check if RawIterator conforms to UnsafeCxxInputIterator. ModuleDecl *module = decl->getModuleContext(); auto rawIteratorConformanceRef = module->lookupConformance(rawIteratorTy, cxxIteratorProto); if (!isConcreteAndValid(rawIteratorConformanceRef, module)) return; auto rawIteratorConformance = rawIteratorConformanceRef.getConcrete(); auto pointeeDecl = cxxIteratorProto->getAssociatedType(ctx.getIdentifier("Pointee")); assert(pointeeDecl && "UnsafeCxxInputIterator must have a Pointee associated type"); auto pointeeTy = rawIteratorConformance->getTypeWitness(pointeeDecl); assert(pointeeTy && "valid conformance must have a Pointee witness"); // Take the default definition of `Iterator` from CxxSequence protocol. This // type is currently `CxxIterator`. auto iteratorDecl = cxxSequenceProto->getAssociatedType(ctx.Id_Iterator); auto iteratorTy = iteratorDecl->getDefaultDefinitionType(); // Substitute generic `Self` parameter. auto cxxSequenceSelfTy = cxxSequenceProto->getSelfInterfaceType(); auto declSelfTy = decl->getDeclaredInterfaceType(); iteratorTy = iteratorTy.subst( [&](SubstitutableType *dependentType) { if (dependentType->isEqual(cxxSequenceSelfTy)) return declSelfTy; return Type(dependentType); }, LookUpConformanceInModule(module)); impl.addSynthesizedTypealias(decl, ctx.Id_Element, pointeeTy); impl.addSynthesizedTypealias(decl, ctx.Id_Iterator, iteratorTy); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawIterator"), rawIteratorTy); // Not conforming the type to CxxSequence protocol here: // The current implementation of CxxSequence triggers extra copies of the C++ // collection when creating a CxxIterator instance. It needs a more efficient // implementation, which is not possible with the existing Swift features. // impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSequence}); // Try to conform to CxxRandomAccessCollection if possible. auto tryToConformToRandomAccessCollection = [&]() -> bool { auto cxxRAIteratorProto = ctx.getProtocol(KnownProtocolKind::UnsafeCxxRandomAccessIterator); if (!cxxRAIteratorProto || !ctx.getProtocol(KnownProtocolKind::CxxRandomAccessCollection)) return false; // Check if RawIterator conforms to UnsafeCxxRandomAccessIterator. auto rawIteratorRAConformanceRef = decl->getModuleContext()->lookupConformance(rawIteratorTy, cxxRAIteratorProto); if (!isConcreteAndValid(rawIteratorRAConformanceRef, module)) return false; // CxxRandomAccessCollection always uses Int as an Index. auto indexTy = ctx.getIntType(); auto sliceTy = ctx.getSliceType(); sliceTy = sliceTy.subst( [&](SubstitutableType *dependentType) { if (dependentType->isEqual(cxxSequenceSelfTy)) return declSelfTy; return Type(dependentType); }, LookUpConformanceInModule(module)); auto indicesTy = ctx.getRangeType(); indicesTy = indicesTy.subst( [&](SubstitutableType *dependentType) { if (dependentType->isEqual(cxxSequenceSelfTy)) return indexTy; return Type(dependentType); }, LookUpConformanceInModule(module)); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Element"), pointeeTy); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Index"), indexTy); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Indices"), indicesTy); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("SubSequence"), sliceTy); impl.addSynthesizedProtocolAttrs( decl, {KnownProtocolKind::CxxRandomAccessCollection}); return true; }; bool conformedToRAC = tryToConformToRandomAccessCollection(); // If the collection does not support random access, let's still allow the // developer to explicitly convert a C++ sequence to a Swift Array (making a // copy of the sequence's elements) by conforming the type to // CxxCollectionConvertible. This enables an overload of Array.init declared // in the Cxx module. if (!conformedToRAC && cxxConvertibleProto) { impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Element"), pointeeTy); impl.addSynthesizedProtocolAttrs( decl, {KnownProtocolKind::CxxConvertibleToCollection}); } } void swift::conformToCxxSetIfNeeded(ClangImporter::Implementation &impl, NominalTypeDecl *decl, const clang::CXXRecordDecl *clangDecl) { PrettyStackTraceDecl trace("conforming to CxxSet", decl); assert(decl); assert(clangDecl); ASTContext &ctx = decl->getASTContext(); // Only auto-conform types from the C++ standard library. Custom user types // might have a similar interface but different semantics. if (!isStdDecl(clangDecl, {"set", "unordered_set", "multiset"})) return; auto valueType = lookupDirectSingleWithoutExtensions( decl, ctx.getIdentifier("value_type")); auto sizeType = lookupDirectSingleWithoutExtensions( decl, ctx.getIdentifier("size_type")); if (!valueType || !sizeType) return; impl.addSynthesizedTypealias(decl, ctx.Id_Element, valueType->getUnderlyingType()); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Size"), sizeType->getUnderlyingType()); impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSet}); } void swift::conformToCxxPairIfNeeded(ClangImporter::Implementation &impl, NominalTypeDecl *decl, const clang::CXXRecordDecl *clangDecl) { PrettyStackTraceDecl trace("conforming to CxxPair", decl); assert(decl); assert(clangDecl); ASTContext &ctx = decl->getASTContext(); // Only auto-conform types from the C++ standard library. Custom user types // might have a similar interface but different semantics. if (!isStdDecl(clangDecl, {"pair"})) return; auto firstType = lookupDirectSingleWithoutExtensions( decl, ctx.getIdentifier("first_type")); auto secondType = lookupDirectSingleWithoutExtensions( decl, ctx.getIdentifier("second_type")); if (!firstType || !secondType) return; impl.addSynthesizedTypealias(decl, ctx.getIdentifier("First"), firstType->getUnderlyingType()); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Second"), secondType->getUnderlyingType()); impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxPair}); } void swift::conformToCxxDictionaryIfNeeded( ClangImporter::Implementation &impl, NominalTypeDecl *decl, const clang::CXXRecordDecl *clangDecl) { PrettyStackTraceDecl trace("conforming to CxxDictionary", decl); assert(decl); assert(clangDecl); ASTContext &ctx = decl->getASTContext(); // Only auto-conform types from the C++ standard library. Custom user types // might have a similar interface but different semantics. if (!isStdDecl(clangDecl, {"map", "unordered_map"})) return; auto keyType = lookupDirectSingleWithoutExtensions( decl, ctx.getIdentifier("key_type")); auto valueType = lookupDirectSingleWithoutExtensions( decl, ctx.getIdentifier("mapped_type")); auto iterType = lookupDirectSingleWithoutExtensions( decl, ctx.getIdentifier("const_iterator")); if (!keyType || !valueType || !iterType) return; // Make the original subscript that returns a non-optional value unavailable. // CxxDictionary adds another subscript that returns an optional value, // similarly to Swift.Dictionary. for (auto member : decl->getCurrentMembersWithoutLoading()) { if (auto subscript = dyn_cast(member)) { impl.markUnavailable(subscript, "use subscript with optional return value"); } } impl.addSynthesizedTypealias(decl, ctx.Id_Key, keyType->getUnderlyingType()); impl.addSynthesizedTypealias(decl, ctx.Id_Value, valueType->getUnderlyingType()); impl.addSynthesizedTypealias(decl, ctx.getIdentifier("RawIterator"), iterType->getUnderlyingType()); impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxDictionary}); }