[cxx-interop] Add UnsafeCxxContiguousIterator & UnsafeCxxMutableContiguousIterator protocols

This adds a pair of Swift protocols that represents C++ iterator types conforming to `std::contiguous_iterator_tag` requirements. These are random access iterators that guarantee that the values are stored in consequent memory addresses.

This will be used to optimize usage of C++ containers such as `std::vector` from Swift, for instance, by providing an overload of `withContiguousStorageIfAvailable` for contiguous containers.

rdar://137877849
This commit is contained in:
Egor Zhdan
2024-10-14 16:02:57 +01:00
parent a9d59034b3
commit 3a200deee9
8 changed files with 294 additions and 63 deletions

View File

@@ -142,6 +142,8 @@ PROTOCOL(UnsafeCxxInputIterator)
PROTOCOL(UnsafeCxxMutableInputIterator)
PROTOCOL(UnsafeCxxRandomAccessIterator)
PROTOCOL(UnsafeCxxMutableRandomAccessIterator)
PROTOCOL(UnsafeCxxContiguousIterator)
PROTOCOL(UnsafeCxxMutableContiguousIterator)
PROTOCOL(AsyncSequence)
PROTOCOL(AsyncIteratorProtocol)

View File

@@ -1444,6 +1444,8 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
case KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator:
case KnownProtocolKind::UnsafeCxxContiguousIterator:
case KnownProtocolKind::UnsafeCxxMutableContiguousIterator:
M = getLoadedModule(Id_Cxx);
break;
case KnownProtocolKind::Copyable:

View File

@@ -462,12 +462,16 @@ void swift::conformToCxxIteratorIfNeeded(
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
};
auto isContiguousIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "contiguous_iterator_tag"); // C++20
};
// Traverse all transitive bases of `underlyingDecl` to check if
// it inherits from `std::input_iterator_tag`.
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
bool isRandomAccessIterator =
isRandomAccessIteratorDecl(underlyingCategoryDecl);
bool isContiguousIterator = isContiguousIteratorDecl(underlyingCategoryDecl);
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
if (isInputIteratorDecl(base)) {
isInputIterator = true;
@@ -475,6 +479,11 @@ void swift::conformToCxxIteratorIfNeeded(
if (isRandomAccessIteratorDecl(base)) {
isRandomAccessIterator = true;
isInputIterator = true;
}
if (isContiguousIteratorDecl(base)) {
isContiguousIterator = true;
isRandomAccessIterator = true;
isInputIterator = true;
return false;
}
return true;
@@ -594,6 +603,15 @@ void swift::conformToCxxIteratorIfNeeded(
else
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator});
if (isContiguousIterator) {
if (pointeeSettable)
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxMutableContiguousIterator});
else
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxContiguousIterator});
}
}
void swift::conformToCxxConvertibleToBoolIfNeeded(

View File

@@ -6967,6 +6967,8 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
case KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator:
case KnownProtocolKind::UnsafeCxxContiguousIterator:
case KnownProtocolKind::UnsafeCxxMutableContiguousIterator:
case KnownProtocolKind::Executor:
case KnownProtocolKind::SerialExecutor:
case KnownProtocolKind::TaskExecutor:

View File

@@ -87,3 +87,15 @@ public protocol UnsafeCxxMutableRandomAccessIterator:
UnsafeCxxRandomAccessIterator, UnsafeCxxMutableInputIterator {}
extension UnsafeMutablePointer: UnsafeCxxMutableRandomAccessIterator {}
/// Bridged C++ iterator that allows traversing elements of a random access
/// collection that are stored in contiguous memory segments.
///
/// Mostly useful for optimizing operations with containers that conform to
/// `CxxRandomAccessCollection` and should not generally be used directly.
///
/// - SeeAlso: https://en.cppreference.com/w/cpp/named_req/ContiguousIterator
public protocol UnsafeCxxContiguousIterator: UnsafeCxxRandomAccessIterator {}
public protocol UnsafeCxxMutableContiguousIterator:
UnsafeCxxContiguousIterator, UnsafeCxxMutableRandomAccessIterator {}

View File

@@ -286,6 +286,229 @@ struct HasTypedefIteratorTag {
}
};
struct MutableRACIterator {
private:
int *value;
public:
struct iterator_category : std::random_access_iterator_tag,
std::output_iterator_tag {};
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;
MutableRACIterator(int *value) : value(value) {}
MutableRACIterator(const MutableRACIterator &other) = default;
const int &operator*() const { return *value; }
int &operator*() { return *value; }
MutableRACIterator &operator++() {
value++;
return *this;
}
MutableRACIterator operator++(int) {
auto tmp = MutableRACIterator(value);
value++;
return tmp;
}
void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
MutableRACIterator operator+(difference_type v) const {
return MutableRACIterator(value + v);
}
MutableRACIterator operator-(difference_type v) const {
return MutableRACIterator(value - v);
}
friend MutableRACIterator operator+(difference_type v,
const MutableRACIterator &it) {
return it + v;
}
int operator-(const MutableRACIterator &other) const {
return value - other.value;
}
bool operator<(const MutableRACIterator &other) const {
return value < other.value;
}
bool operator==(const MutableRACIterator &other) const {
return value == other.value;
}
bool operator!=(const MutableRACIterator &other) const {
return value != other.value;
}
};
#if __cplusplus >= 202002L
struct ConstContiguousIterator {
private:
const int *value;
public:
using iterator_category = std::contiguous_iterator_tag;
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;
ConstContiguousIterator(const int *value) : value(value) {}
ConstContiguousIterator(const ConstContiguousIterator &other) = default;
const int &operator*() const { return *value; }
ConstContiguousIterator &operator++() {
value++;
return *this;
}
ConstContiguousIterator operator++(int) {
auto tmp = ConstContiguousIterator(value);
value++;
return tmp;
}
void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
ConstContiguousIterator operator+(difference_type v) const {
return ConstContiguousIterator(value + v);
}
ConstContiguousIterator operator-(difference_type v) const {
return ConstContiguousIterator(value - v);
}
friend ConstContiguousIterator operator+(difference_type v,
const ConstContiguousIterator &it) {
return it + v;
}
int operator-(const ConstContiguousIterator &other) const {
return value - other.value;
}
bool operator<(const ConstContiguousIterator &other) const {
return value < other.value;
}
bool operator==(const ConstContiguousIterator &other) const {
return value == other.value;
}
bool operator!=(const ConstContiguousIterator &other) const {
return value != other.value;
}
};
struct HasCustomContiguousIteratorTag {
private:
const int *value;
public:
struct CustomTag : std::contiguous_iterator_tag {};
using iterator_category = CustomTag;
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;
HasCustomContiguousIteratorTag(const int *value) : value(value) {}
HasCustomContiguousIteratorTag(const HasCustomContiguousIteratorTag &other) =
default;
const int &operator*() const { return *value; }
HasCustomContiguousIteratorTag &operator++() {
value++;
return *this;
}
HasCustomContiguousIteratorTag operator++(int) {
auto tmp = HasCustomContiguousIteratorTag(value);
value++;
return tmp;
}
void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
HasCustomContiguousIteratorTag operator+(difference_type v) const {
return HasCustomContiguousIteratorTag(value + v);
}
HasCustomContiguousIteratorTag operator-(difference_type v) const {
return HasCustomContiguousIteratorTag(value - v);
}
friend HasCustomContiguousIteratorTag
operator+(difference_type v, const HasCustomContiguousIteratorTag &it) {
return it + v;
}
int operator-(const HasCustomContiguousIteratorTag &other) const {
return value - other.value;
}
bool operator<(const HasCustomContiguousIteratorTag &other) const {
return value < other.value;
}
bool operator==(const HasCustomContiguousIteratorTag &other) const {
return value == other.value;
}
bool operator!=(const HasCustomContiguousIteratorTag &other) const {
return value != other.value;
}
};
struct MutableContiguousIterator {
private:
int *value;
public:
using iterator_category = std::contiguous_iterator_tag;
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;
MutableContiguousIterator(int *value) : value(value) {}
MutableContiguousIterator(const MutableContiguousIterator &other) = default;
const int &operator*() const { return *value; }
int &operator*() { return *value; }
MutableContiguousIterator &operator++() {
value++;
return *this;
}
MutableContiguousIterator operator++(int) {
auto tmp = MutableContiguousIterator(value);
value++;
return tmp;
}
void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
MutableContiguousIterator operator+(difference_type v) const {
return MutableContiguousIterator(value + v);
}
MutableContiguousIterator operator-(difference_type v) const {
return MutableContiguousIterator(value - v);
}
friend MutableContiguousIterator
operator+(difference_type v, const MutableContiguousIterator &it) {
return it + v;
}
int operator-(const MutableContiguousIterator &other) const {
return value - other.value;
}
bool operator<(const MutableContiguousIterator &other) const {
return value < other.value;
}
bool operator==(const MutableContiguousIterator &other) const {
return value == other.value;
}
bool operator!=(const MutableContiguousIterator &other) const {
return value != other.value;
}
};
#endif
// MARK: Types that are not actually iterators
struct HasNoIteratorCategory {
@@ -916,62 +1139,6 @@ public:
}
};
struct MutableRACIterator {
private:
int *value;
public:
struct iterator_category : std::random_access_iterator_tag,
std::output_iterator_tag {};
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;
MutableRACIterator(int *value) : value(value) {}
MutableRACIterator(const MutableRACIterator &other) = default;
const int &operator*() const { return *value; }
int &operator*() { return *value; }
MutableRACIterator &operator++() {
value++;
return *this;
}
MutableRACIterator operator++(int) {
auto tmp = MutableRACIterator(value);
value++;
return tmp;
}
void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
MutableRACIterator operator+(difference_type v) const {
return MutableRACIterator(value + v);
}
MutableRACIterator operator-(difference_type v) const {
return MutableRACIterator(value - v);
}
friend MutableRACIterator operator+(difference_type v,
const MutableRACIterator &it) {
return it + v;
}
int operator-(const MutableRACIterator &other) const {
return value - other.value;
}
bool operator<(const MutableRACIterator &other) const {
return value < other.value;
}
bool operator==(const MutableRACIterator &other) const {
return value == other.value;
}
bool operator!=(const MutableRACIterator &other) const {
return value != other.value;
}
};
/// clang::StmtIteratorBase
class ProtectedIteratorBase {
protected:

View File

@@ -0,0 +1,28 @@
// RUN: %target-swift-ide-test -print-module -module-to-print=CustomIterator -source-filename=x -I %S/Inputs -cxx-interoperability-mode=swift-6 -Xcc -std=c++20 | %FileCheck %s
// RUN: %target-swift-ide-test -print-module -module-to-print=CustomIterator -source-filename=x -I %S/Inputs -cxx-interoperability-mode=upcoming-swift -Xcc -std=c++20 | %FileCheck %s
// Ubuntu 20.04 ships with an old version of libstdc++, which does not provide
// std::contiguous_iterator_tag from C++20.
// UNSUPPORTED: LinuxDistribution=ubuntu-20.04
// UNSUPPORTED: LinuxDistribution=amzn-2
// CHECK: struct ConstContiguousIterator : UnsafeCxxContiguousIterator, UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: func successor() -> ConstContiguousIterator
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }
// CHECK: struct HasCustomContiguousIteratorTag : UnsafeCxxContiguousIterator, UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: func successor() -> HasCustomContiguousIteratorTag
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }
// CHECK: struct MutableContiguousIterator : UnsafeCxxMutableContiguousIterator, UnsafeCxxMutableRandomAccessIterator, UnsafeCxxMutableInputIterator {
// CHECK: func successor() -> MutableContiguousIterator
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }

View File

@@ -80,6 +80,13 @@
// CHECK: static func == (lhs: HasTypedefIteratorTag, other: HasTypedefIteratorTag) -> Bool
// CHECK: }
// CHECK: struct MutableRACIterator : UnsafeCxxMutableRandomAccessIterator, UnsafeCxxMutableInputIterator {
// CHECK: func successor() -> MutableRACIterator
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }
// CHECK-NOT: struct HasNoIteratorCategory : UnsafeCxxInputIterator
// CHECK-NOT: struct HasInvalidIteratorCategory : UnsafeCxxInputIterator
// CHECK-NOT: struct HasNoEqualEqual : UnsafeCxxInputIterator
@@ -139,10 +146,3 @@
// CHECK: var pointee: Int32 { get nonmutating set }
// CHECK: typealias Pointee = Int32
// CHECK: }
// CHECK: struct MutableRACIterator : UnsafeCxxMutableRandomAccessIterator, UnsafeCxxMutableInputIterator {
// CHECK: func successor() -> MutableRACIterator
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }