[cxx-interop] Add UnsafeCxxMutableInputIterator protocol

This is an inheritor of the existing `UnsafeCxxInputIterator` protocol, with the only difference being the ability to mutate `var pointee` via a non-const `operator*()`.

This is needed to support mutable subscripts for `std::map` via `CxxDictionary`.

rdar://105399019
This commit is contained in:
Egor Zhdan
2023-07-26 16:15:49 +01:00
parent 61d73d2b6f
commit 8d7d0efe13
7 changed files with 84 additions and 3 deletions

View File

@@ -116,6 +116,7 @@ PROTOCOL(CxxRandomAccessCollection)
PROTOCOL(CxxSequence)
PROTOCOL(CxxUniqueSet)
PROTOCOL(UnsafeCxxInputIterator)
PROTOCOL(UnsafeCxxMutableInputIterator)
PROTOCOL(UnsafeCxxRandomAccessIterator)
PROTOCOL(AsyncSequence)

View File

@@ -1140,6 +1140,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::CxxUniqueSet:
case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
M = getLoadedModule(Id_Cxx);
break;

View File

@@ -432,6 +432,11 @@ void swift::conformToCxxIteratorIfNeeded(
if (!pointee || pointee->isGetterMutating() || pointee->getType()->hasError())
return;
// Check if `var pointee: Pointee` is settable. This is required for the
// conformance to UnsafeCxxMutableInputIterator but is not necessary for
// UnsafeCxxInputIterator.
bool pointeeSettable = pointee->isSettable(nullptr);
// Check if present: `func successor() -> Self`
auto successorId = ctx.getIdentifier("successor");
auto successor =
@@ -469,8 +474,13 @@ void swift::conformToCxxIteratorIfNeeded(
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Pointee"),
pointee->getType());
impl.addSynthesizedProtocolAttrs(decl,
{KnownProtocolKind::UnsafeCxxInputIterator});
if (pointeeSettable)
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxMutableInputIterator});
else
impl.addSynthesizedProtocolAttrs(
decl, {KnownProtocolKind::UnsafeCxxInputIterator});
if (!isRandomAccessIterator ||
!ctx.getProtocol(KnownProtocolKind::UnsafeCxxRandomAccessIterator))
return;

View File

@@ -6311,6 +6311,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::CxxSequence:
case KnownProtocolKind::CxxUniqueSet:
case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
case KnownProtocolKind::Executor:
case KnownProtocolKind::SerialExecutor:

View File

@@ -57,6 +57,10 @@ extension Optional: UnsafeCxxInputIterator where Wrapped: UnsafeCxxInputIterator
}
}
public protocol UnsafeCxxMutableInputIterator: UnsafeCxxInputIterator {
override var pointee: Pointee { get set }
}
/// Bridged C++ iterator that allows computing the distance between two of its
/// instances, and advancing an instance by a given number of elements.
///

View File

@@ -866,4 +866,60 @@ public:
}
};
struct MutableRACIterator {
private:
int value;
public:
struct iterator_category : std::random_access_iterator_tag,
std::output_iterator_tag {};
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;
MutableRACIterator(int value) : value(value) {}
MutableRACIterator(const MutableRACIterator &other) = default;
const int &operator*() const { return value; }
int &operator*() { return value; }
MutableRACIterator &operator++() {
value++;
return *this;
}
MutableRACIterator operator++(int) {
auto tmp = MutableRACIterator(value);
value++;
return tmp;
}
void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
MutableRACIterator operator+(difference_type v) const {
return MutableRACIterator(value + v);
}
MutableRACIterator operator-(difference_type v) const {
return MutableRACIterator(value - v);
}
friend MutableRACIterator operator+(difference_type v,
const MutableRACIterator &it) {
return it + v;
}
int operator-(const MutableRACIterator &other) const {
return value - other.value;
}
bool operator<(const MutableRACIterator &other) const {
return value < other.value;
}
bool operator==(const MutableRACIterator &other) const {
return value == other.value;
}
bool operator!=(const MutableRACIterator &other) const {
return value != other.value;
}
};
#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_CUSTOM_ITERATOR_H

View File

@@ -126,7 +126,15 @@
// CHECK: struct InheritedTemplatedConstRACIteratorOutOfLineOps<Int32> : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: }
// CHECK: struct InputOutputIterator : UnsafeCxxInputIterator {
// CHECK: struct InputOutputIterator : UnsafeCxxMutableInputIterator {
// CHECK: func successor() -> InputOutputIterator
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: }
// CHECK: struct MutableRACIterator : UnsafeCxxRandomAccessIterator, UnsafeCxxMutableInputIterator {
// CHECK: func successor() -> MutableRACIterator
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }