Merge pull request #61700 from apple/egorzhdan/cxx-conform-raciter

[cxx-interop] Synthesize conformances to `UnsafeCxxInputIterator`
This commit is contained in:
Egor Zhdan
2022-10-26 18:34:32 +01:00
committed by GitHub
7 changed files with 174 additions and 25 deletions

View File

@@ -107,6 +107,7 @@ PROTOCOL(DistributedTargetInvocationResultHandler)
// C++ Standard Library Overlay: // C++ Standard Library Overlay:
PROTOCOL(CxxSequence) PROTOCOL(CxxSequence)
PROTOCOL(UnsafeCxxInputIterator) PROTOCOL(UnsafeCxxInputIterator)
PROTOCOL(UnsafeCxxRandomAccessIterator)
PROTOCOL(AsyncSequence) PROTOCOL(AsyncSequence)
PROTOCOL(AsyncIteratorProtocol) PROTOCOL(AsyncIteratorProtocol)

View File

@@ -1115,6 +1115,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
break; break;
case KnownProtocolKind::CxxSequence: case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::UnsafeCxxInputIterator: case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
M = getLoadedModule(Id_Cxx); M = getLoadedModule(Id_Cxx);
break; break;
default: default:

View File

@@ -53,9 +53,29 @@ getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory); return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
} }
static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) { static ValueDecl *lookupOperator(NominalTypeDecl *decl, Identifier id,
auto id = decl->getASTContext().Id_EqualsOperator; function_ref<bool(ValueDecl *)> isValid) {
// First look for operator declared as a member.
auto memberResults = lookupDirectWithoutExtensions(decl, id);
for (const auto &member : memberResults) {
if (isValid(member))
return member;
}
// If no member operator was found, look for out-of-class definitions in the
// same module.
auto module = decl->getModuleContext();
SmallVector<ValueDecl *> nonMemberResults;
module->lookupValue(id, NLKind::UnqualifiedLookup, nonMemberResults);
for (const auto &nonMember : nonMemberResults) {
if (isValid(nonMember))
return nonMember;
}
return nullptr;
}
static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
auto isValid = [&](ValueDecl *equalEqualOp) -> bool { auto isValid = [&](ValueDecl *equalEqualOp) -> bool {
auto equalEqual = dyn_cast<FuncDecl>(equalEqualOp); auto equalEqual = dyn_cast<FuncDecl>(equalEqualOp);
if (!equalEqual || !equalEqual->hasParameterList()) if (!equalEqual || !equalEqual->hasParameterList())
@@ -78,24 +98,72 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
return true; return true;
}; };
// First look for `func ==` declared as a member. return lookupOperator(decl, decl->getASTContext().Id_EqualsOperator, isValid);
auto memberResults = lookupDirectWithoutExtensions(decl, id); }
for (const auto &member : memberResults) {
if (isValid(member))
return member;
}
// If no member `func ==` was found, look for out-of-class definitions in the static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
// same module. auto binaryIntegerProto =
decl->getASTContext().getProtocol(KnownProtocolKind::BinaryInteger);
auto module = decl->getModuleContext(); auto module = decl->getModuleContext();
SmallVector<ValueDecl *> nonMemberResults;
module->lookupValue(id, NLKind::UnqualifiedLookup, nonMemberResults);
for (const auto &nonMember : nonMemberResults) {
if (isValid(nonMember))
return nonMember;
}
return nullptr; auto isValid = [&](ValueDecl *minusOp) -> bool {
auto minus = dyn_cast<FuncDecl>(minusOp);
if (!minus || !minus->hasParameterList())
return false;
auto params = minus->getParameters();
if (params->size() != 2)
return false;
auto lhs = params->get(0);
auto rhs = params->get(1);
if (lhs->isInOut() || rhs->isInOut())
return false;
auto lhsTy = lhs->getType();
auto rhsTy = rhs->getType();
if (!lhsTy || !rhsTy)
return false;
auto lhsNominal = lhsTy->getAnyNominal();
auto rhsNominal = rhsTy->getAnyNominal();
if (lhsNominal != rhsNominal || lhsNominal != decl)
return false;
auto returnTy = minus->getResultInterfaceType();
if (!module->conformsToProtocol(returnTy, binaryIntegerProto))
return false;
return true;
};
return lookupOperator(decl, decl->getASTContext().getIdentifier("-"),
isValid);
}
static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
if (!plusEqual || !plusEqual->hasParameterList())
return false;
auto params = plusEqual->getParameters();
if (params->size() != 2)
return false;
auto lhs = params->get(0);
auto rhs = params->get(1);
if (rhs->isInOut())
return false;
auto lhsTy = lhs->getType();
auto rhsTy = rhs->getType();
if (!lhsTy || !rhsTy)
return false;
if (rhsTy->getCanonicalType() != distanceTy->getCanonicalType())
return false;
auto lhsNominal = lhsTy->getAnyNominal();
if (lhsNominal != decl)
return false;
auto returnTy = plusEqual->getResultInterfaceType();
if (!returnTy->isVoid())
return false;
return true;
};
return lookupOperator(decl, decl->getASTContext().getIdentifier("+="),
isValid);
} }
bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) { bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
@@ -111,6 +179,9 @@ void swift::conformToCxxIteratorIfNeeded(
assert(clangDecl); assert(clangDecl);
ASTContext &ctx = decl->getASTContext(); ASTContext &ctx = decl->getASTContext();
if (!ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator))
return;
// We consider a type to be an input iterator if it defines an // We consider a type to be an input iterator if it defines an
// `iterator_category` that inherits from `std::input_iterator_tag`, e.g. // `iterator_category` that inherits from `std::input_iterator_tag`, e.g.
// `using iterator_category = std::input_iterator_tag`. // `using iterator_category = std::input_iterator_tag`.
@@ -134,17 +205,30 @@ void swift::conformToCxxIteratorIfNeeded(
if (!underlyingCategoryDecl) if (!underlyingCategoryDecl)
return; return;
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) { auto isIteratorCategoryDecl = [&](const clang::CXXRecordDecl *base,
StringRef tag) {
return base->isInStdNamespace() && base->getIdentifier() && return base->isInStdNamespace() && base->getIdentifier() &&
base->getName() == "input_iterator_tag"; base->getName() == tag;
};
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "input_iterator_tag");
};
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
}; };
// Traverse all transitive bases of `underlyingDecl` to check if // Traverse all transitive bases of `underlyingDecl` to check if
// it inherits from `std::input_iterator_tag`. // it inherits from `std::input_iterator_tag`.
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl); bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
bool isRandomAccessIterator =
isRandomAccessIteratorDecl(underlyingCategoryDecl);
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) { underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
if (isInputIteratorDecl(base)) { if (isInputIteratorDecl(base)) {
isInputIterator = true; isInputIterator = true;
}
if (isRandomAccessIteratorDecl(base)) {
isRandomAccessIterator = true;
isInputIterator = true;
return false; return false;
} }
return true; return true;
@@ -183,6 +267,25 @@ void swift::conformToCxxIteratorIfNeeded(
pointee->getType()); pointee->getType());
impl.addSynthesizedProtocolAttrs(decl, impl.addSynthesizedProtocolAttrs(decl,
{KnownProtocolKind::UnsafeCxxInputIterator}); {KnownProtocolKind::UnsafeCxxInputIterator});
if (!isRandomAccessIterator ||
!ctx.getProtocol(KnownProtocolKind::UnsafeCxxRandomAccessIterator))
return;
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
auto minus = dyn_cast<FuncDecl>(getMinusOperator(decl));
if (!minus)
return;
auto distanceTy = minus->getResultInterfaceType();
// distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.
auto plusEqual = dyn_cast<FuncDecl>(getPlusEqualOperator(decl, distanceTy));
if (!plusEqual)
return;
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Distance"), distanceTy);
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator});
} }
void swift::conformToCxxSequenceIfNeeded( void swift::conformToCxxSequenceIfNeeded(

View File

@@ -5886,6 +5886,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::DistributedTargetInvocationResultHandler: case KnownProtocolKind::DistributedTargetInvocationResultHandler:
case KnownProtocolKind::CxxSequence: case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::UnsafeCxxInputIterator: case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
case KnownProtocolKind::SerialExecutor: case KnownProtocolKind::SerialExecutor:
case KnownProtocolKind::Sendable: case KnownProtocolKind::Sendable:
case KnownProtocolKind::UnsafeSendable: case KnownProtocolKind::UnsafeSendable:

View File

@@ -240,6 +240,25 @@ struct HasCustomIteratorTag {
} }
}; };
struct HasCustomRACIteratorTag {
struct CustomTag : public std::random_access_iterator_tag {};
int value;
using iterator_category = CustomTag;
const int &operator*() const { return value; }
HasCustomRACIteratorTag &operator++() {
value++;
return *this;
}
void operator+=(int x) { value += x; }
int operator-(const HasCustomRACIteratorTag &x) const {
return value - x.value;
}
bool operator==(const HasCustomRACIteratorTag &other) const {
return value == other.value;
}
};
struct HasCustomIteratorTagInline { struct HasCustomIteratorTagInline {
struct iterator_category : public std::input_iterator_tag {}; struct iterator_category : public std::input_iterator_tag {};

View File

@@ -11,9 +11,6 @@ var CxxCollectionTestSuite = TestSuite("CxxCollection")
// === SimpleCollectionNoSubscript === // === SimpleCollectionNoSubscript ===
extension SimpleCollectionNoSubscript.iterator : UnsafeCxxRandomAccessIterator {
public typealias Distance = difference_type
}
extension SimpleCollectionNoSubscript : CxxRandomAccessCollection { extension SimpleCollectionNoSubscript : CxxRandomAccessCollection {
} }
@@ -25,9 +22,6 @@ CxxCollectionTestSuite.test("SimpleCollectionNoSubscript as Swift.Collection") {
// === SimpleCollectionReadOnly === // === SimpleCollectionReadOnly ===
extension SimpleCollectionReadOnly.iterator : UnsafeCxxRandomAccessIterator {
public typealias Distance = difference_type
}
extension SimpleCollectionReadOnly : CxxRandomAccessCollection { extension SimpleCollectionReadOnly : CxxRandomAccessCollection {
} }

View File

@@ -7,6 +7,26 @@
// CHECK: typealias Pointee = Int32 // CHECK: typealias Pointee = Int32
// CHECK: } // CHECK: }
// CHECK: struct ConstRACIterator : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: var pointee: Int32 { get }
// CHECK: func successor() -> ConstRACIterator
// CHECK: static func += (lhs: inout ConstRACIterator, v: ConstRACIterator.difference_type)
// CHECK: static func - (lhs: ConstRACIterator, other: ConstRACIterator) -> Int32
// CHECK: static func == (lhs: ConstRACIterator, other: ConstRACIterator) -> Bool
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }
// CHECK: struct ConstRACIteratorRefPlusEq : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: var pointee: Int32 { get }
// CHECK: func successor() -> ConstRACIterator
// CHECK: static func += (lhs: inout ConstRACIteratorRefPlusEq, v: ConstRACIteratorRefPlusEq.difference_type)
// CHECK: static func - (lhs: ConstRACIteratorRefPlusEq, other: ConstRACIteratorRefPlusEq) -> Int32
// CHECK: static func == (lhs: ConstRACIteratorRefPlusEq, other: ConstRACIteratorRefPlusEq) -> Bool
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }
// CHECK: struct ConstIteratorOutOfLineEq : UnsafeCxxInputIterator { // CHECK: struct ConstIteratorOutOfLineEq : UnsafeCxxInputIterator {
// CHECK: var pointee: Int32 { get } // CHECK: var pointee: Int32 { get }
// CHECK: func successor() -> ConstIteratorOutOfLineEq // CHECK: func successor() -> ConstIteratorOutOfLineEq
@@ -34,6 +54,16 @@
// CHECK: typealias Pointee = Int32 // CHECK: typealias Pointee = Int32
// CHECK: } // CHECK: }
// CHECK: struct HasCustomRACIteratorTag : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: var pointee: Int32 { get }
// CHECK: func successor() -> HasCustomRACIteratorTag
// CHECK: static func += (lhs: inout HasCustomRACIteratorTag, x: Int32)
// CHECK: static func - (lhs: HasCustomRACIteratorTag, x: HasCustomRACIteratorTag) -> Int32
// CHECK: static func == (lhs: HasCustomRACIteratorTag, other: HasCustomRACIteratorTag) -> Bool
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }
// CHECK: struct HasCustomIteratorTagInline : UnsafeCxxInputIterator { // CHECK: struct HasCustomIteratorTagInline : UnsafeCxxInputIterator {
// CHECK: var pointee: Int32 { get } // CHECK: var pointee: Int32 { get }
// CHECK: func successor() -> HasCustomIteratorTagInline // CHECK: func successor() -> HasCustomIteratorTagInline