Merge pull request #61700 from apple/egorzhdan/cxx-conform-raciter

[cxx-interop] Synthesize conformances to `UnsafeCxxInputIterator`
This commit is contained in:
Egor Zhdan
2022-10-26 18:34:32 +01:00
committed by GitHub
7 changed files with 174 additions and 25 deletions

View File

@@ -107,6 +107,7 @@ PROTOCOL(DistributedTargetInvocationResultHandler)
// C++ Standard Library Overlay:
PROTOCOL(CxxSequence)
PROTOCOL(UnsafeCxxInputIterator)
PROTOCOL(UnsafeCxxRandomAccessIterator)
PROTOCOL(AsyncSequence)
PROTOCOL(AsyncIteratorProtocol)

View File

@@ -1115,6 +1115,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
break;
case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
M = getLoadedModule(Id_Cxx);
break;
default:

View File

@@ -53,9 +53,29 @@ getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
}
static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
auto id = decl->getASTContext().Id_EqualsOperator;
static ValueDecl *lookupOperator(NominalTypeDecl *decl, Identifier id,
function_ref<bool(ValueDecl *)> isValid) {
// First look for operator declared as a member.
auto memberResults = lookupDirectWithoutExtensions(decl, id);
for (const auto &member : memberResults) {
if (isValid(member))
return member;
}
// If no member operator was found, look for out-of-class definitions in the
// same module.
auto module = decl->getModuleContext();
SmallVector<ValueDecl *> nonMemberResults;
module->lookupValue(id, NLKind::UnqualifiedLookup, nonMemberResults);
for (const auto &nonMember : nonMemberResults) {
if (isValid(nonMember))
return nonMember;
}
return nullptr;
}
static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
auto isValid = [&](ValueDecl *equalEqualOp) -> bool {
auto equalEqual = dyn_cast<FuncDecl>(equalEqualOp);
if (!equalEqual || !equalEqual->hasParameterList())
@@ -78,24 +98,72 @@ static ValueDecl *getEqualEqualOperator(NominalTypeDecl *decl) {
return true;
};
// First look for `func ==` declared as a member.
auto memberResults = lookupDirectWithoutExtensions(decl, id);
for (const auto &member : memberResults) {
if (isValid(member))
return member;
}
return lookupOperator(decl, decl->getASTContext().Id_EqualsOperator, isValid);
}
// If no member `func ==` was found, look for out-of-class definitions in the
// same module.
static ValueDecl *getMinusOperator(NominalTypeDecl *decl) {
auto binaryIntegerProto =
decl->getASTContext().getProtocol(KnownProtocolKind::BinaryInteger);
auto module = decl->getModuleContext();
SmallVector<ValueDecl *> nonMemberResults;
module->lookupValue(id, NLKind::UnqualifiedLookup, nonMemberResults);
for (const auto &nonMember : nonMemberResults) {
if (isValid(nonMember))
return nonMember;
}
return nullptr;
auto isValid = [&](ValueDecl *minusOp) -> bool {
auto minus = dyn_cast<FuncDecl>(minusOp);
if (!minus || !minus->hasParameterList())
return false;
auto params = minus->getParameters();
if (params->size() != 2)
return false;
auto lhs = params->get(0);
auto rhs = params->get(1);
if (lhs->isInOut() || rhs->isInOut())
return false;
auto lhsTy = lhs->getType();
auto rhsTy = rhs->getType();
if (!lhsTy || !rhsTy)
return false;
auto lhsNominal = lhsTy->getAnyNominal();
auto rhsNominal = rhsTy->getAnyNominal();
if (lhsNominal != rhsNominal || lhsNominal != decl)
return false;
auto returnTy = minus->getResultInterfaceType();
if (!module->conformsToProtocol(returnTy, binaryIntegerProto))
return false;
return true;
};
return lookupOperator(decl, decl->getASTContext().getIdentifier("-"),
isValid);
}
static ValueDecl *getPlusEqualOperator(NominalTypeDecl *decl, Type distanceTy) {
auto isValid = [&](ValueDecl *plusEqualOp) -> bool {
auto plusEqual = dyn_cast<FuncDecl>(plusEqualOp);
if (!plusEqual || !plusEqual->hasParameterList())
return false;
auto params = plusEqual->getParameters();
if (params->size() != 2)
return false;
auto lhs = params->get(0);
auto rhs = params->get(1);
if (rhs->isInOut())
return false;
auto lhsTy = lhs->getType();
auto rhsTy = rhs->getType();
if (!lhsTy || !rhsTy)
return false;
if (rhsTy->getCanonicalType() != distanceTy->getCanonicalType())
return false;
auto lhsNominal = lhsTy->getAnyNominal();
if (lhsNominal != decl)
return false;
auto returnTy = plusEqual->getResultInterfaceType();
if (!returnTy->isVoid())
return false;
return true;
};
return lookupOperator(decl, decl->getASTContext().getIdentifier("+="),
isValid);
}
bool swift::isIterator(const clang::CXXRecordDecl *clangDecl) {
@@ -111,6 +179,9 @@ void swift::conformToCxxIteratorIfNeeded(
assert(clangDecl);
ASTContext &ctx = decl->getASTContext();
if (!ctx.getProtocol(KnownProtocolKind::UnsafeCxxInputIterator))
return;
// We consider a type to be an input iterator if it defines an
// `iterator_category` that inherits from `std::input_iterator_tag`, e.g.
// `using iterator_category = std::input_iterator_tag`.
@@ -134,17 +205,30 @@ void swift::conformToCxxIteratorIfNeeded(
if (!underlyingCategoryDecl)
return;
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
auto isIteratorCategoryDecl = [&](const clang::CXXRecordDecl *base,
StringRef tag) {
return base->isInStdNamespace() && base->getIdentifier() &&
base->getName() == "input_iterator_tag";
base->getName() == tag;
};
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "input_iterator_tag");
};
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
};
// Traverse all transitive bases of `underlyingDecl` to check if
// it inherits from `std::input_iterator_tag`.
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
bool isRandomAccessIterator =
isRandomAccessIteratorDecl(underlyingCategoryDecl);
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
if (isInputIteratorDecl(base)) {
isInputIterator = true;
}
if (isRandomAccessIteratorDecl(base)) {
isRandomAccessIterator = true;
isInputIterator = true;
return false;
}
return true;
@@ -183,6 +267,25 @@ void swift::conformToCxxIteratorIfNeeded(
pointee->getType());
impl.addSynthesizedProtocolAttrs(decl,
{KnownProtocolKind::UnsafeCxxInputIterator});
if (!isRandomAccessIterator ||
!ctx.getProtocol(KnownProtocolKind::UnsafeCxxRandomAccessIterator))
return;
// Try to conform to UnsafeCxxRandomAccessIterator if possible.
auto minus = dyn_cast<FuncDecl>(getMinusOperator(decl));
if (!minus)
return;
auto distanceTy = minus->getResultInterfaceType();
// distanceTy conforms to BinaryInteger, this is ensured by getMinusOperator.
auto plusEqual = dyn_cast<FuncDecl>(getPlusEqualOperator(decl, distanceTy));
if (!plusEqual)
return;
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Distance"), distanceTy);
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator});
}
void swift::conformToCxxSequenceIfNeeded(

View File

@@ -5886,6 +5886,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::DistributedTargetInvocationResultHandler:
case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
case KnownProtocolKind::SerialExecutor:
case KnownProtocolKind::Sendable:
case KnownProtocolKind::UnsafeSendable:

View File

@@ -240,6 +240,25 @@ struct HasCustomIteratorTag {
}
};
struct HasCustomRACIteratorTag {
struct CustomTag : public std::random_access_iterator_tag {};
int value;
using iterator_category = CustomTag;
const int &operator*() const { return value; }
HasCustomRACIteratorTag &operator++() {
value++;
return *this;
}
void operator+=(int x) { value += x; }
int operator-(const HasCustomRACIteratorTag &x) const {
return value - x.value;
}
bool operator==(const HasCustomRACIteratorTag &other) const {
return value == other.value;
}
};
struct HasCustomIteratorTagInline {
struct iterator_category : public std::input_iterator_tag {};

View File

@@ -11,9 +11,6 @@ var CxxCollectionTestSuite = TestSuite("CxxCollection")
// === SimpleCollectionNoSubscript ===
extension SimpleCollectionNoSubscript.iterator : UnsafeCxxRandomAccessIterator {
public typealias Distance = difference_type
}
extension SimpleCollectionNoSubscript : CxxRandomAccessCollection {
}
@@ -25,9 +22,6 @@ CxxCollectionTestSuite.test("SimpleCollectionNoSubscript as Swift.Collection") {
// === SimpleCollectionReadOnly ===
extension SimpleCollectionReadOnly.iterator : UnsafeCxxRandomAccessIterator {
public typealias Distance = difference_type
}
extension SimpleCollectionReadOnly : CxxRandomAccessCollection {
}

View File

@@ -7,6 +7,26 @@
// CHECK: typealias Pointee = Int32
// CHECK: }
// CHECK: struct ConstRACIterator : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: var pointee: Int32 { get }
// CHECK: func successor() -> ConstRACIterator
// CHECK: static func += (lhs: inout ConstRACIterator, v: ConstRACIterator.difference_type)
// CHECK: static func - (lhs: ConstRACIterator, other: ConstRACIterator) -> Int32
// CHECK: static func == (lhs: ConstRACIterator, other: ConstRACIterator) -> Bool
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }
// CHECK: struct ConstRACIteratorRefPlusEq : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: var pointee: Int32 { get }
// CHECK: func successor() -> ConstRACIterator
// CHECK: static func += (lhs: inout ConstRACIteratorRefPlusEq, v: ConstRACIteratorRefPlusEq.difference_type)
// CHECK: static func - (lhs: ConstRACIteratorRefPlusEq, other: ConstRACIteratorRefPlusEq) -> Int32
// CHECK: static func == (lhs: ConstRACIteratorRefPlusEq, other: ConstRACIteratorRefPlusEq) -> Bool
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }
// CHECK: struct ConstIteratorOutOfLineEq : UnsafeCxxInputIterator {
// CHECK: var pointee: Int32 { get }
// CHECK: func successor() -> ConstIteratorOutOfLineEq
@@ -34,6 +54,16 @@
// CHECK: typealias Pointee = Int32
// CHECK: }
// CHECK: struct HasCustomRACIteratorTag : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: var pointee: Int32 { get }
// CHECK: func successor() -> HasCustomRACIteratorTag
// CHECK: static func += (lhs: inout HasCustomRACIteratorTag, x: Int32)
// CHECK: static func - (lhs: HasCustomRACIteratorTag, x: HasCustomRACIteratorTag) -> Int32
// CHECK: static func == (lhs: HasCustomRACIteratorTag, other: HasCustomRACIteratorTag) -> Bool
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }
// CHECK: struct HasCustomIteratorTagInline : UnsafeCxxInputIterator {
// CHECK: var pointee: Int32 { get }
// CHECK: func successor() -> HasCustomIteratorTagInline