mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Merge pull request #61700 from apple/egorzhdan/cxx-conform-raciter
[cxx-interop] Synthesize conformances to `UnsafeCxxInputIterator`
This commit is contained in:
@@ -107,6 +107,7 @@ PROTOCOL(DistributedTargetInvocationResultHandler)
|
||||
// C++ Standard Library Overlay:
|
||||
PROTOCOL(CxxSequence)
|
||||
PROTOCOL(UnsafeCxxInputIterator)
|
||||
PROTOCOL(UnsafeCxxRandomAccessIterator)
|
||||
|
||||
PROTOCOL(AsyncSequence)
|
||||
PROTOCOL(AsyncIteratorProtocol)
|
||||
|
||||
@@ -1115,6 +1115,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
|
||||
break;
|
||||
case KnownProtocolKind::CxxSequence:
|
||||
case KnownProtocolKind::UnsafeCxxInputIterator:
|
||||
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
|
||||
M = getLoadedModule(Id_Cxx);
|
||||
break;
|
||||
default:
|
||||
|
||||
@@ -53,9 +53,29 @@ getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
|
||||
return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
|
||||
}
|
||||
|
||||
static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
|
||||
auto id = decl->getASTContext().Id_EqualsOperator;
|
||||
static ValueDecl *lookupOperator(NominalTypeDecl *decl, Identifier id,
|
||||
function_ref<bool(ValueDecl *)> isValid) {
|
||||
// First look for operator declared as a member.
|
||||
auto memberResults = lookupDirectWithoutExtensions(decl, id);
|
||||
for (const auto &member : memberResults) {
|
||||
if (isValid(member))
|
||||
return member;
|
||||
}
|
||||
|
||||
// If no member operator was found, look for out-of-class definitions in the
|
||||
// same module.
|
||||
auto module = decl->getModuleContext();
|
||||
SmallVector<ValueDecl *> nonMemberResults;
|
||||
module->lookupValue(id, NLKind::UnqualifiedLookup, nonMemberResults);
|
||||
for (const auto &nonMember : nonMemberResults) {
|
||||
if (isValid(nonMember))
|
||||
return nonMember;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
|
||||
auto isValid = [&](ValueDecl *equalEqualOp) -> bool {
|
||||
auto equalEqual = dyn_cast<FuncDecl>(equalEqualOp);
|
||||
if (!equalEqual || !equalEqual->hasParameterList())
|
||||
@@ -78,24 +98,72 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
|
||||
return true;
|
||||
};
|
||||
|
||||
// First look for `func ==` declared as a member.
|
||||
auto memberResults = lookupDirectWithoutExtensions(decl, id);
|
||||
for (const auto &member : memberResults) {
|
||||
if (isValid(member))
|
||||
return member;
|
||||
}
|
||||
return lookupOperator(decl, decl->getASTContext().Id_EqualsOperator, isValid);
|
||||
}
|
||||
|
||||
// If no member `func ==` was found, look for out-of-class definitions in the
|
||||
// same module.
|
||||
static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
|
||||
auto binaryIntegerProto =
|
||||
decl->getASTContext().getProtocol(KnownProtocolKind::BinaryInteger);
|
||||
auto module = decl->getModuleContext();
|
||||
SmallVector<ValueDecl *> nonMemberResults;
|
||||
module->lookupValue(id, NLKind::UnqualifiedLookup, nonMemberResults);
|
||||
for (const auto &nonMember : nonMemberResults) {
|
||||
if (isValid(nonMember))
|
||||
return nonMember;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
auto isValid = [&](ValueDecl *minusOp) -> bool {
|
||||
auto minus = dyn_cast<FuncDecl>(minusOp);
|
||||
if (!minus || !minus->hasParameterList())
|
||||
return false;
|
||||
auto params = minus->getParameters();
|
||||
if (params->size() != 2)
|
||||
return false;
|
||||
auto lhs = params->get(0);
|
||||
auto rhs = params->get(1);
|
||||
if (lhs->isInOut() || rhs->isInOut())
|
||||
return false;
|
||||
auto lhsTy = lhs->getType();
|
||||
auto rhsTy = rhs->getType();
|
||||
if (!lhsTy || !rhsTy)
|
||||
return false;
|
||||
auto lhsNominal = lhsTy->getAnyNominal();
|
||||
auto rhsNominal = rhsTy->getAnyNominal();
|
||||
if (lhsNominal != rhsNominal || lhsNominal != decl)
|
||||
return false;
|
||||
auto returnTy = minus->getResultInterfaceType();
|
||||
if (!module->conformsToProtocol(returnTy, binaryIntegerProto))
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
|
||||
return lookupOperator(decl, decl->getASTContext().getIdentifier("-"),
|
||||
isValid);
|
||||
}
|
||||
|
||||
static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
|
||||
auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
|
||||
auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
|
||||
if (!plusEqual || !plusEqual->hasParameterList())
|
||||
return false;
|
||||
auto params = plusEqual->getParameters();
|
||||
if (params->size() != 2)
|
||||
return false;
|
||||
auto lhs = params->get(0);
|
||||
auto rhs = params->get(1);
|
||||
if (rhs->isInOut())
|
||||
return false;
|
||||
auto lhsTy = lhs->getType();
|
||||
auto rhsTy = rhs->getType();
|
||||
if (!lhsTy || !rhsTy)
|
||||
return false;
|
||||
if (rhsTy->getCanonicalType() != distanceTy->getCanonicalType())
|
||||
return false;
|
||||
auto lhsNominal = lhsTy->getAnyNominal();
|
||||
if (lhsNominal != decl)
|
||||
return false;
|
||||
auto returnTy = plusEqual->getResultInterfaceType();
|
||||
if (!returnTy->isVoid())
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
|
||||
return lookupOperator(decl, decl->getASTContext().getIdentifier("+="),
|
||||
isValid);
|
||||
}
|
||||
|
||||
bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
|
||||
@@ -111,6 +179,9 @@ void swift::conformToCxxIteratorIfNeeded(
|
||||
assert(clangDecl);
|
||||
ASTContext &ctx = decl->getASTContext();
|
||||
|
||||
if (!ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator))
|
||||
return;
|
||||
|
||||
// We consider a type to be an input iterator if it defines an
|
||||
// `iterator_category` that inherits from `std::input_iterator_tag`, e.g.
|
||||
// `using iterator_category = std::input_iterator_tag`.
|
||||
@@ -134,17 +205,30 @@ void swift::conformToCxxIteratorIfNeeded(
|
||||
if (!underlyingCategoryDecl)
|
||||
return;
|
||||
|
||||
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
|
||||
auto isIteratorCategoryDecl = [&](const clang::CXXRecordDecl *base,
|
||||
StringRef tag) {
|
||||
return base->isInStdNamespace() && base->getIdentifier() &&
|
||||
base->getName() == "input_iterator_tag";
|
||||
base->getName() == tag;
|
||||
};
|
||||
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
|
||||
return isIteratorCategoryDecl(base, "input_iterator_tag");
|
||||
};
|
||||
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
|
||||
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
|
||||
};
|
||||
|
||||
// Traverse all transitive bases of `underlyingDecl` to check if
|
||||
// it inherits from `std::input_iterator_tag`.
|
||||
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
|
||||
bool isRandomAccessIterator =
|
||||
isRandomAccessIteratorDecl(underlyingCategoryDecl);
|
||||
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
|
||||
if (isInputIteratorDecl(base)) {
|
||||
isInputIterator = true;
|
||||
}
|
||||
if (isRandomAccessIteratorDecl(base)) {
|
||||
isRandomAccessIterator = true;
|
||||
isInputIterator = true;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@@ -183,6 +267,25 @@ void swift::conformToCxxIteratorIfNeeded(
|
||||
pointee->getType());
|
||||
impl.addSynthesizedProtocolAttrs(decl,
|
||||
{KnownProtocolKind::UnsafeCxxInputIterator});
|
||||
if (!isRandomAccessIterator ||
|
||||
!ctx.getProtocol(KnownProtocolKind::UnsafeCxxRandomAccessIterator))
|
||||
return;
|
||||
|
||||
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
|
||||
|
||||
auto minus = dyn_cast<FuncDecl>(getMinusOperator(decl));
|
||||
if (!minus)
|
||||
return;
|
||||
auto distanceTy = minus->getResultInterfaceType();
|
||||
// distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.
|
||||
|
||||
auto plusEqual = dyn_cast<FuncDecl>(getPlusEqualOperator(decl, distanceTy));
|
||||
if (!plusEqual)
|
||||
return;
|
||||
|
||||
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Distance"), distanceTy);
|
||||
impl.addSynthesizedProtocolAttrs(
|
||||
decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator});
|
||||
}
|
||||
|
||||
void swift::conformToCxxSequenceIfNeeded(
|
||||
|
||||
@@ -5886,6 +5886,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
|
||||
case KnownProtocolKind::DistributedTargetInvocationResultHandler:
|
||||
case KnownProtocolKind::CxxSequence:
|
||||
case KnownProtocolKind::UnsafeCxxInputIterator:
|
||||
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
|
||||
case KnownProtocolKind::SerialExecutor:
|
||||
case KnownProtocolKind::Sendable:
|
||||
case KnownProtocolKind::UnsafeSendable:
|
||||
|
||||
@@ -240,6 +240,25 @@ struct HasCustomIteratorTag {
|
||||
}
|
||||
};
|
||||
|
||||
struct HasCustomRACIteratorTag {
|
||||
struct CustomTag : public std::random_access_iterator_tag {};
|
||||
|
||||
int value;
|
||||
using iterator_category = CustomTag;
|
||||
const int &operator*() const { return value; }
|
||||
HasCustomRACIteratorTag &operator++() {
|
||||
value++;
|
||||
return *this;
|
||||
}
|
||||
void operator+=(int x) { value += x; }
|
||||
int operator-(const HasCustomRACIteratorTag &x) const {
|
||||
return value - x.value;
|
||||
}
|
||||
bool operator==(const HasCustomRACIteratorTag &other) const {
|
||||
return value == other.value;
|
||||
}
|
||||
};
|
||||
|
||||
struct HasCustomIteratorTagInline {
|
||||
struct iterator_category : public std::input_iterator_tag {};
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ var CxxCollectionTestSuite = TestSuite("CxxCollection")
|
||||
|
||||
// === SimpleCollectionNoSubscript ===
|
||||
|
||||
extension SimpleCollectionNoSubscript.iterator : UnsafeCxxRandomAccessIterator {
|
||||
public typealias Distance = difference_type
|
||||
}
|
||||
extension SimpleCollectionNoSubscript : CxxRandomAccessCollection {
|
||||
}
|
||||
|
||||
@@ -25,9 +22,6 @@ CxxCollectionTestSuite.test("SimpleCollectionNoSubscript as Swift.Collection") {
|
||||
|
||||
// === SimpleCollectionReadOnly ===
|
||||
|
||||
extension SimpleCollectionReadOnly.iterator : UnsafeCxxRandomAccessIterator {
|
||||
public typealias Distance = difference_type
|
||||
}
|
||||
extension SimpleCollectionReadOnly : CxxRandomAccessCollection {
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,26 @@
|
||||
// CHECK: typealias Pointee = Int32
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: struct ConstRACIterator : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
|
||||
// CHECK: var pointee: Int32 { get }
|
||||
// CHECK: func successor() -> ConstRACIterator
|
||||
// CHECK: static func += (lhs: inout ConstRACIterator, v: ConstRACIterator.difference_type)
|
||||
// CHECK: static func - (lhs: ConstRACIterator, other: ConstRACIterator) -> Int32
|
||||
// CHECK: static func == (lhs: ConstRACIterator, other: ConstRACIterator) -> Bool
|
||||
// CHECK: typealias Pointee = Int32
|
||||
// CHECK: typealias Distance = Int32
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: struct ConstRACIteratorRefPlusEq : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
|
||||
// CHECK: var pointee: Int32 { get }
|
||||
// CHECK: func successor() -> ConstRACIterator
|
||||
// CHECK: static func += (lhs: inout ConstRACIteratorRefPlusEq, v: ConstRACIteratorRefPlusEq.difference_type)
|
||||
// CHECK: static func - (lhs: ConstRACIteratorRefPlusEq, other: ConstRACIteratorRefPlusEq) -> Int32
|
||||
// CHECK: static func == (lhs: ConstRACIteratorRefPlusEq, other: ConstRACIteratorRefPlusEq) -> Bool
|
||||
// CHECK: typealias Pointee = Int32
|
||||
// CHECK: typealias Distance = Int32
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: struct ConstIteratorOutOfLineEq : UnsafeCxxInputIterator {
|
||||
// CHECK: var pointee: Int32 { get }
|
||||
// CHECK: func successor() -> ConstIteratorOutOfLineEq
|
||||
@@ -34,6 +54,16 @@
|
||||
// CHECK: typealias Pointee = Int32
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: struct HasCustomRACIteratorTag : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
|
||||
// CHECK: var pointee: Int32 { get }
|
||||
// CHECK: func successor() -> HasCustomRACIteratorTag
|
||||
// CHECK: static func += (lhs: inout HasCustomRACIteratorTag, x: Int32)
|
||||
// CHECK: static func - (lhs: HasCustomRACIteratorTag, x: HasCustomRACIteratorTag) -> Int32
|
||||
// CHECK: static func == (lhs: HasCustomRACIteratorTag, other: HasCustomRACIteratorTag) -> Bool
|
||||
// CHECK: typealias Pointee = Int32
|
||||
// CHECK: typealias Distance = Int32
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: struct HasCustomIteratorTagInline : UnsafeCxxInputIterator {
|
||||
// CHECK: var pointee: Int32 { get }
|
||||
// CHECK: func successor() -> HasCustomIteratorTagInline
|
||||
|
||||
Reference in New Issue
Block a user