[cxx-interop] Conform to UnsafeCxxContiguousIterator based on iterator_concept nested type

3a200dee has a logic bug where we tried to conform C++ iterator types to `UnsafeCxxContiguousIterator` protocol based on their nested type called `iterator_category`. The C++20 standard says we should rely on `iterator_concept` instead.

https://en.cppreference.com/w/cpp/iterator/iterator_tags#Iterator_concept

Despite what the name suggests, we are not actually using C++ concepts in this change.

rdar://137877849
This commit is contained in:
Egor Zhdan
2024-10-16 16:03:38 +01:00
parent 844d103f0d
commit 34f6cd3f1a
3 changed files with 117 additions and 36 deletions

View File

@@ -125,18 +125,12 @@ lookupNestedClangTypeDecl(const clang::CXXRecordDecl *clangDecl,
static clang::TypeDecl *
getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
clang::IdentifierInfo *iteratorCategoryDeclName =
&clangDecl->getASTContext().Idents.get("iterator_category");
auto iteratorCategories = clangDecl->lookup(iteratorCategoryDeclName);
// If this is a templated typedef, Clang might have instantiated several
// equivalent typedef decls. If they aren't equivalent, Clang has already
// complained about this. Let's assume that they are equivalent. (see
// filterNonConflictingPreviousTypedefDecls in clang/Sema/SemaDecl.cpp)
if (iteratorCategories.empty())
return nullptr;
auto iteratorCategory = iteratorCategories.front();
return lookupNestedClangTypeDecl(clangDecl, "iterator_category");
}
return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
static clang::TypeDecl *
getIteratorConceptDecl(const clang::CXXRecordDecl *clangDecl) {
return lookupNestedClangTypeDecl(clangDecl, "iterator_concept");
}
static ValueDecl *lookupOperator(NominalTypeDecl *decl, Identifier id,
@@ -435,35 +429,40 @@ void swift::conformToCxxIteratorIfNeeded(
if (!iteratorCategory)
return;
auto unwrapUnderlyingTypeDecl =
[](clang::TypeDecl *typeDecl) -> clang::CXXRecordDecl * {
clang::CXXRecordDecl *underlyingDecl = nullptr;
if (auto typedefDecl = dyn_cast<clang::TypedefNameDecl>(typeDecl)) {
auto type = typedefDecl->getUnderlyingType();
underlyingDecl = type->getAsCXXRecordDecl();
} else {
underlyingDecl = dyn_cast<clang::CXXRecordDecl>(typeDecl);
}
if (underlyingDecl) {
underlyingDecl = underlyingDecl->getDefinition();
}
return underlyingDecl;
};
// If `iterator_category` is a typedef or a using-decl, retrieve the
// underlying struct decl.
clang::CXXRecordDecl *underlyingCategoryDecl = nullptr;
if (auto typedefDecl = dyn_cast<clang::TypedefNameDecl>(iteratorCategory)) {
auto type = typedefDecl->getUnderlyingType();
underlyingCategoryDecl = type->getAsCXXRecordDecl();
} else {
underlyingCategoryDecl = dyn_cast<clang::CXXRecordDecl>(iteratorCategory);
}
if (underlyingCategoryDecl) {
underlyingCategoryDecl = underlyingCategoryDecl->getDefinition();
}
auto underlyingCategoryDecl = unwrapUnderlyingTypeDecl(iteratorCategory);
if (!underlyingCategoryDecl)
return;
auto isIteratorCategoryDecl = [&](const clang::CXXRecordDecl *base,
auto isIteratorTagDecl = [&](const clang::CXXRecordDecl *base,
StringRef tag) {
return base->isInStdNamespace() && base->getIdentifier() &&
base->getName() == tag;
};
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "input_iterator_tag");
return isIteratorTagDecl(base, "input_iterator_tag");
};
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
return isIteratorTagDecl(base, "random_access_iterator_tag");
};
auto isContiguousIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "contiguous_iterator_tag"); // C++20
return isIteratorTagDecl(base, "contiguous_iterator_tag"); // C++20
};
// Traverse all transitive bases of `underlyingDecl` to check if
@@ -471,7 +470,6 @@ void swift::conformToCxxIteratorIfNeeded(
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
bool isRandomAccessIterator =
isRandomAccessIteratorDecl(underlyingCategoryDecl);
bool isContiguousIterator = isContiguousIteratorDecl(underlyingCategoryDecl);
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
if (isInputIteratorDecl(base)) {
isInputIterator = true;
@@ -479,11 +477,6 @@ void swift::conformToCxxIteratorIfNeeded(
if (isRandomAccessIteratorDecl(base)) {
isRandomAccessIterator = true;
isInputIterator = true;
}
if (isContiguousIteratorDecl(base)) {
isContiguousIterator = true;
isRandomAccessIterator = true;
isInputIterator = true;
return false;
}
return true;
@@ -492,6 +485,27 @@ void swift::conformToCxxIteratorIfNeeded(
if (!isInputIterator)
return;
bool isContiguousIterator = false;
// In C++20, `std::contiguous_iterator_tag` is specified as a type called
// `iterator_concept`. It is not possible to detect a contiguous iterator
// based on its `iterator_category`. The type might not have an
// `iterator_concept` defined.
if (auto iteratorConcept = getIteratorConceptDecl(clangDecl)) {
if (auto underlyingConceptDecl =
unwrapUnderlyingTypeDecl(iteratorConcept)) {
isContiguousIterator = isContiguousIteratorDecl(underlyingConceptDecl);
if (!isContiguousIterator)
underlyingConceptDecl->forallBases(
[&](const clang::CXXRecordDecl *base) {
if (isContiguousIteratorDecl(base)) {
isContiguousIterator = true;
return false;
}
return true;
});
}
}
// Check if present: `var pointee: Pointee { get }`
auto pointeeId = ctx.getIdentifier("pointee");
auto pointee = lookupDirectSingleWithoutExtensions<VarDecl>(decl, pointeeId);

View File

@@ -348,7 +348,8 @@ private:
const int *value;
public:
using iterator_category = std::contiguous_iterator_tag;
using iterator_category = std::random_access_iterator_tag;
using iterator_concept = std::contiguous_iterator_tag;
using value_type = int;
using pointer = int *;
using reference = const int &;
@@ -403,7 +404,8 @@ private:
public:
struct CustomTag : std::contiguous_iterator_tag {};
using iterator_category = CustomTag;
using iterator_category = std::random_access_iterator_tag;
using iterator_concept = CustomTag;
using value_type = int;
using pointer = int *;
using reference = const int &;
@@ -458,7 +460,8 @@ private:
int *value;
public:
using iterator_category = std::contiguous_iterator_tag;
using iterator_category = std::random_access_iterator_tag;
using iterator_concept = std::contiguous_iterator_tag;
using value_type = int;
using pointer = int *;
using reference = const int &;
@@ -507,6 +510,63 @@ public:
return value != other.value;
}
};
/// This is actually just a random access iterator
struct HasNoContiguousIteratorConcept {
private:
const int *value;
public:
using iterator_category = std::contiguous_iterator_tag;
// no iterator_concept
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;
HasNoContiguousIteratorConcept(const int *value) : value(value) {}
HasNoContiguousIteratorConcept(const HasNoContiguousIteratorConcept &other) =
default;
const int &operator*() const { return *value; }
HasNoContiguousIteratorConcept &operator++() {
value++;
return *this;
}
HasNoContiguousIteratorConcept operator++(int) {
auto tmp = HasNoContiguousIteratorConcept(value);
value++;
return tmp;
}
void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
HasNoContiguousIteratorConcept operator+(difference_type v) const {
return HasNoContiguousIteratorConcept(value + v);
}
HasNoContiguousIteratorConcept operator-(difference_type v) const {
return HasNoContiguousIteratorConcept(value - v);
}
friend HasNoContiguousIteratorConcept
operator+(difference_type v, const HasNoContiguousIteratorConcept &it) {
return it + v;
}
int operator-(const HasNoContiguousIteratorConcept &other) const {
return value - other.value;
}
bool operator<(const HasNoContiguousIteratorConcept &other) const {
return value < other.value;
}
bool operator==(const HasNoContiguousIteratorConcept &other) const {
return value == other.value;
}
bool operator!=(const HasNoContiguousIteratorConcept &other) const {
return value != other.value;
}
};
#endif
// MARK: Types that are not actually iterators

View File

@@ -26,3 +26,10 @@
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }
// CHECK: struct HasNoContiguousIteratorConcept : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: func successor() -> HasNoContiguousIteratorConcept
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }