[cxx-interop] Allow retain/release operations to be methods

Some foreign reference types such as IUnknown define retain/release operations as methods of the type.

Previously Swift only supported retain/release operations as standalone functions.

The syntax for member functions would be `SWIFT_SHARED_REFERENCE(.doRetain, .doRelease)`.

rdar://160696723
(cherry picked from commit e78ce6165f)
This commit is contained in:
Egor Zhdan
2025-08-22 18:18:14 +01:00
parent 9016636714
commit c5c33d6c5a
11 changed files with 324 additions and 15 deletions

View File

@@ -265,6 +265,9 @@ ERROR(foreign_reference_types_release_non_void_return_type, none,
ERROR(foreign_reference_types_retain_release_not_a_function_decl, none, ERROR(foreign_reference_types_retain_release_not_a_function_decl, none,
"specified %select{retain|release}0 function '%1' is not a function", "specified %select{retain|release}0 function '%1' is not a function",
(bool, StringRef)) (bool, StringRef))
ERROR(foreign_reference_types_retain_release_not_an_instance_function, none,
"specified %select{retain|release}0 function '%1' is a static function; expected an instance function",
(bool, StringRef))
ERROR(conforms_to_missing_dot, none, ERROR(conforms_to_missing_dot, none,
"expected module name and protocol name separated by '.' in protocol " "expected module name and protocol name separated by '.' in protocol "
"conformance; '%0' is invalid", "conformance; '%0' is invalid",

View File

@@ -7876,10 +7876,27 @@ getRefParentDecls(const clang::RecordDecl *decl, ASTContext &ctx,
} }
llvm::SmallVector<ValueDecl *, 1> llvm::SmallVector<ValueDecl *, 1>
importer::getValueDeclsForName( importer::getValueDeclsForName(NominalTypeDecl *decl, StringRef name) {
const clang::Decl *decl, ASTContext &ctx, StringRef name) { // If the name is empty, don't try to find any decls.
if (name.empty())
return {};
auto &ctx = decl->getASTContext();
auto clangDecl = decl->getClangDecl();
llvm::SmallVector<ValueDecl *, 1> results; llvm::SmallVector<ValueDecl *, 1> results;
auto *clangMod = decl->getOwningModule();
if (name.starts_with(".")) {
// Look for a member of decl instead of a global.
StringRef memberName = name.drop_front(1);
if (memberName.empty())
return {};
auto declName = DeclName(ctx.getIdentifier(memberName));
auto allResults = evaluateOrDefault(
ctx.evaluator, ClangRecordMemberLookup({decl, declName}), {});
return SmallVector<ValueDecl *, 1>(allResults.begin(), allResults.end());
}
auto *clangMod = clangDecl->getOwningModule();
if (clangMod && clangMod->isSubModule()) if (clangMod && clangMod->isSubModule())
clangMod = clangMod->getTopLevelModule(); clangMod = clangMod->getTopLevelModule();
if (clangMod) { if (clangMod) {
@@ -8487,7 +8504,7 @@ CustomRefCountingOperationResult CustomRefCountingOperation::evaluate(
return {CustomRefCountingOperationResult::immortal, nullptr, name}; return {CustomRefCountingOperationResult::immortal, nullptr, name};
llvm::SmallVector<ValueDecl *, 1> results = llvm::SmallVector<ValueDecl *, 1> results =
getValueDeclsForName(swiftDecl->getClangDecl(), ctx, name); getValueDeclsForName(const_cast<ClassDecl*>(swiftDecl), name);
if (results.size() == 1) if (results.size() == 1)
return {CustomRefCountingOperationResult::foundOperation, results.front(), return {CustomRefCountingOperationResult::foundOperation, results.front(),
name}; name};

View File

@@ -2747,6 +2747,7 @@ namespace {
enum class RetainReleaseOperationKind { enum class RetainReleaseOperationKind {
notAfunction, notAfunction,
notAnInstanceFunction,
invalidReturnType, invalidReturnType,
invalidParameters, invalidParameters,
valid valid
@@ -2760,17 +2761,32 @@ namespace {
if (!operationFn) if (!operationFn)
return RetainReleaseOperationKind::notAfunction; return RetainReleaseOperationKind::notAfunction;
if (operationFn->getParameters()->size() != 1) if (operationFn->isStatic())
return RetainReleaseOperationKind::invalidParameters; return RetainReleaseOperationKind::notAnInstanceFunction;
Type paramType = if (operationFn->isInstanceMember()) {
operationFn->getParameters()->get(0)->getInterfaceType(); if (operationFn->getParameters()->size() != 0)
// Unwrap if paramType is an OptionalType return RetainReleaseOperationKind::invalidParameters;
if (Type optionalType = paramType->getOptionalObjectType()) { } else {
paramType = optionalType; if (operationFn->getParameters()->size() != 1)
return RetainReleaseOperationKind::invalidParameters;
} }
swift::NominalTypeDecl *paramDecl = paramType->getAnyNominal(); Type paramType;
NominalTypeDecl *paramDecl = nullptr;
if (!operationFn->isInstanceMember()) {
paramType =
operationFn->getParameters()->get(0)->getInterfaceType();
// Unwrap if paramType is an OptionalType
if (Type optionalType = paramType->getOptionalObjectType()) {
paramType = optionalType;
}
paramDecl = paramType->getAnyNominal();
} else {
paramDecl = cast<NominalTypeDecl>(operationFn->getParent());
paramType = paramDecl->getDeclaredInterfaceType();
}
// The return type should be void (for release functions), or void // The return type should be void (for release functions), or void
// or the parameter type (for retain functions). // or the parameter type (for retain functions).
@@ -2855,6 +2871,12 @@ namespace {
diag::foreign_reference_types_retain_release_not_a_function_decl, diag::foreign_reference_types_retain_release_not_a_function_decl,
false, retainOperation.name); false, retainOperation.name);
break; break;
case RetainReleaseOperationKind::notAnInstanceFunction:
Impl.diagnose(
loc,
diag::foreign_reference_types_retain_release_not_an_instance_function,
false, retainOperation.name);
break;
case RetainReleaseOperationKind::invalidReturnType: case RetainReleaseOperationKind::invalidReturnType:
Impl.diagnose( Impl.diagnose(
loc, loc,
@@ -2920,6 +2942,12 @@ namespace {
diag::foreign_reference_types_retain_release_not_a_function_decl, diag::foreign_reference_types_retain_release_not_a_function_decl,
true, releaseOperation.name); true, releaseOperation.name);
break; break;
case RetainReleaseOperationKind::notAnInstanceFunction:
Impl.diagnose(
loc,
diag::foreign_reference_types_retain_release_not_an_instance_function,
true, releaseOperation.name);
break;
case RetainReleaseOperationKind::invalidReturnType: case RetainReleaseOperationKind::invalidReturnType:
Impl.diagnose( Impl.diagnose(
loc, loc,

View File

@@ -2151,7 +2151,7 @@ ImportedType findOptionSetEnum(clang::QualType type,
/// ///
/// The name we're looking for is the Swift name. /// The name we're looking for is the Swift name.
llvm::SmallVector<ValueDecl *, 1> llvm::SmallVector<ValueDecl *, 1>
getValueDeclsForName(const clang::Decl *decl, ASTContext &ctx, StringRef name); getValueDeclsForName(NominalTypeDecl* decl, StringRef name);
} // end namespace importer } // end namespace importer
} // end namespace swift } // end namespace swift

View File

@@ -2767,8 +2767,7 @@ FuncDecl *SwiftDeclSynthesizer::findExplicitDestroy(
if (!destroyFuncName.consume_front("destroy:")) if (!destroyFuncName.consume_front("destroy:"))
continue; continue;
auto decls = getValueDeclsForName( auto decls = getValueDeclsForName(nominal, destroyFuncName);
clangType, nominal->getASTContext(), destroyFuncName);
for (auto decl : decls) { for (auto decl : decls) {
auto func = dyn_cast<FuncDecl>(decl); auto func = dyn_cast<FuncDecl>(decl);
if (!func) if (!func)

View File

@@ -1722,6 +1722,11 @@ void IRGenFunction::emitBlockRelease(llvm::Value *value) {
void IRGenFunction::emitForeignReferenceTypeLifetimeOperation( void IRGenFunction::emitForeignReferenceTypeLifetimeOperation(
ValueDecl *fn, llvm::Value *value, bool needsNullCheck) { ValueDecl *fn, llvm::Value *value, bool needsNullCheck) {
if (auto originalDecl = fn->getASTContext()
.getClangModuleLoader()
->getOriginalForClonedMember(fn))
fn = originalDecl;
assert(fn->getClangDecl() && isa<clang::FunctionDecl>(fn->getClangDecl())); assert(fn->getClangDecl() && isa<clang::FunctionDecl>(fn->getClangDecl()));
auto clangFn = cast<clang::FunctionDecl>(fn->getClangDecl()); auto clangFn = cast<clang::FunctionDecl>(fn->getClangDecl());

View File

@@ -0,0 +1,121 @@
#include <swift/bridging>
struct RefCountedBox {
int value;
int refCount = 1;
RefCountedBox(int value) : value(value) {}
void doRetain() { refCount++; }
void doRelease() { refCount--; }
} SWIFT_SHARED_REFERENCE(.doRetain, .doRelease);
struct DerivedRefCountedBox : RefCountedBox {
int secondValue = 1;
DerivedRefCountedBox(int value, int secondValue)
: RefCountedBox(value), secondValue(secondValue) {}
};
// MARK: Retain in a base type, release in derived
struct BaseHasRetain {
mutable int refCount = 1;
void doRetainInBase() const { refCount++; }
};
struct DerivedHasRelease : BaseHasRetain {
int value;
DerivedHasRelease(int value) : value(value) {}
void doRelease() const { refCount--; }
} SWIFT_SHARED_REFERENCE(.doRetainInBase, .doRelease);
// MARK: Retain in a base type, release in templated derived
template <typename T>
struct TemplatedDerivedHasRelease : BaseHasRetain {
T value;
TemplatedDerivedHasRelease(T value) : value(value) {}
void doReleaseTemplated() const { refCount--; }
} SWIFT_SHARED_REFERENCE(.doRetainInBase, .doReleaseTemplated);
using TemplatedDerivedHasReleaseFloat = TemplatedDerivedHasRelease<float>;
using TemplatedDerivedHasReleaseInt = TemplatedDerivedHasRelease<int>;
// MARK: Retain/release in CRTP base type
template <typename Derived>
struct CRTPBase {
mutable int refCount = 1;
void crtpRetain() const { refCount++; }
void crtpRelease() const { refCount--; }
} SWIFT_SHARED_REFERENCE(.crtpRetain, .crtpRelease);
struct CRTPDerived : CRTPBase<CRTPDerived> {
int value;
CRTPDerived(int value) : value(value) {}
};
// MARK: Virtual retain and release
struct VirtualRetainRelease {
int value;
mutable int refCount = 1;
VirtualRetainRelease(int value) : value(value) {}
virtual void doRetainVirtual() const { refCount++; }
virtual void doReleaseVirtual() const { refCount--; }
virtual ~VirtualRetainRelease() = default;
} SWIFT_SHARED_REFERENCE(.doRetainVirtual, .doReleaseVirtual);
struct DerivedVirtualRetainRelease : VirtualRetainRelease {
DerivedVirtualRetainRelease(int value) : VirtualRetainRelease(value) {}
mutable bool calledDerived = false;
void doRetainVirtual() const override { refCount++; calledDerived = true; }
void doReleaseVirtual() const override { refCount--; }
};
// MARK: Pure virtual retain and release
struct PureVirtualRetainRelease {
int value;
mutable int refCount = 1;
PureVirtualRetainRelease(int value) : value(value) {}
virtual void doRetainPure() const = 0;
virtual void doReleasePure() const = 0;
virtual ~PureVirtualRetainRelease() = default;
} SWIFT_SHARED_REFERENCE(.doRetainPure, .doReleasePure);
struct DerivedPureVirtualRetainRelease : PureVirtualRetainRelease {
mutable int refCount = 1;
DerivedPureVirtualRetainRelease(int value) : PureVirtualRetainRelease(value) {}
void doRetainPure() const override { refCount++; }
void doReleasePure() const override { refCount--; }
};
// MARK: Static retain/release
#ifdef INCORRECT
struct StaticRetainRelease {
// expected-error@-1 {{specified retain function '.staticRetain' is a static function; expected an instance function}}
// expected-error@-2 {{specified release function '.staticRelease' is a static function; expected an instance function}}
int value;
int refCount = 1;
StaticRetainRelease(int value) : value(value) {}
static void staticRetain(StaticRetainRelease* o) { o->refCount++; }
static void staticRelease(StaticRetainRelease* o) { o->refCount--; }
} SWIFT_SHARED_REFERENCE(.staticRetain, .staticRelease);
struct DerivedStaticRetainRelease : StaticRetainRelease {
// expected-error@-1 {{cannot find retain function '.staticRetain' for reference type 'DerivedStaticRetainRelease'}}
// expected-error@-2 {{cannot find release function '.staticRelease' for reference type 'DerivedStaticRetainRelease'}}
int secondValue = 1;
DerivedStaticRetainRelease(int value, int secondValue)
: StaticRetainRelease(value), secondValue(secondValue) {}
};
#endif

View File

@@ -44,6 +44,11 @@ module ReferenceCountedObjCProperty {
export * export *
} }
module LifetimeOperationMethods {
header "lifetime-operation-methods.h"
requires cplusplus
}
module MemberLayout { module MemberLayout {
header "member-layout.h" header "member-layout.h"
requires cplusplus requires cplusplus

View File

@@ -0,0 +1,49 @@
// RUN: %target-swift-ide-test -print-module -cxx-interoperability-mode=upcoming-swift -I %swift_src_root/lib/ClangImporter/SwiftBridging -module-to-print=LifetimeOperationMethods -I %S/Inputs -source-filename=x | %FileCheck %s
// CHECK: class RefCountedBox {
// CHECK: func doRetain()
// CHECK: func doRelease()
// CHECK: }
// CHECK: class DerivedRefCountedBox {
// CHECK: func doRetain()
// CHECK: func doRelease()
// CHECK: }
// CHECK: class DerivedHasRelease {
// CHECK: func doRelease()
// CHECK: func doRetainInBase()
// CHECK: }
// CHECK: class TemplatedDerivedHasRelease<CFloat> {
// CHECK: var value: Float
// CHECK: func doReleaseTemplated()
// CHECK: func doRetainInBase()
// CHECK: }
// CHECK: class TemplatedDerivedHasRelease<CInt> {
// CHECK: var value: Int32
// CHECK: func doReleaseTemplated()
// CHECK: func doRetainInBase()
// CHECK: }
// CHECK: class CRTPDerived {
// CHECK: var value: Int32
// CHECK: }
// CHECK: class VirtualRetainRelease {
// CHECK: func doRetainVirtual()
// CHECK: func doReleaseVirtual()
// CHECK: }
// CHECK: class DerivedVirtualRetainRelease {
// CHECK: func doRetainVirtual()
// CHECK: func doReleaseVirtual()
// CHECK: }
// CHECK: class PureVirtualRetainRelease {
// CHECK: func doRetainPure()
// CHECK: func doReleasePure()
// CHECK: }
// CHECK: class DerivedPureVirtualRetainRelease {
// CHECK: func doRetainPure()
// CHECK: func doReleasePure()
// CHECK: var refCount: Int32
// CHECK: }

View File

@@ -0,0 +1,6 @@
// RUN: %target-typecheck-verify-swift -Xcc -DINCORRECT -I %S%{fs-sep}Inputs -I %swift_src_root/lib/ClangImporter/SwiftBridging -verify-additional-file %S%{fs-sep}Inputs%{fs-sep}lifetime-operation-methods.h -cxx-interoperability-mode=upcoming-swift -disable-availability-checking
import LifetimeOperationMethods
let _ = StaticRetainRelease(123)
let _ = DerivedStaticRetainRelease(123, 456)

View File

@@ -0,0 +1,76 @@
// RUN: %target-run-simple-swift(-I %S/Inputs -cxx-interoperability-mode=upcoming-swift -I %swift_src_root/lib/ClangImporter/SwiftBridging -Xfrontend -disable-availability-checking)
// Temporarily disable when running with an older runtime (rdar://128681137)
// UNSUPPORTED: use_os_stdlib
// UNSUPPORTED: back_deployment_runtime
import StdlibUnittest
import LifetimeOperationMethods
var LifetimeMethodsTestSuite = TestSuite("Lifetime operations that are instance methods")
LifetimeMethodsTestSuite.test("retain/release methods") {
let a = RefCountedBox(123)
expectEqual(a.value, 123)
expectTrue(a.refCount > 0)
expectTrue(a.refCount < 10) // optimizations would affect the exact number
}
LifetimeMethodsTestSuite.test("retain/release methods from base type") {
let a = DerivedRefCountedBox(321, 456)
expectEqual(a.value, 321)
expectEqual(a.secondValue, 456)
expectTrue(a.refCount > 0)
expectTrue(a.refCount < 10) // optimizations would affect the exact number
a.secondValue = 789
expectEqual(a.secondValue, 789)
}
LifetimeMethodsTestSuite.test("retain in base type, release in derived type") {
let a = DerivedHasRelease(321)
expectEqual(a.value, 321)
expectTrue(a.refCount > 0)
expectTrue(a.refCount < 10) // optimizations would affect the exact number
}
LifetimeMethodsTestSuite.test("retain in base type, release in derived templated type") {
let a = TemplatedDerivedHasReleaseInt(456)
expectEqual(a.value, 456)
expectTrue(a.refCount > 0)
expectTrue(a.refCount < 10) // optimizations would affect the exact number
let b = TemplatedDerivedHasReleaseFloat(5.66)
expectEqual(b.value, 5.66)
}
LifetimeMethodsTestSuite.test("CRTP") {
let a = CRTPDerived(789)
expectEqual(a.value, 789)
expectTrue(a.refCount > 0)
expectTrue(a.refCount < 10) // optimizations would affect the exact number
}
LifetimeMethodsTestSuite.test("virtual retain/release") {
let a = VirtualRetainRelease(456)
expectEqual(a.value, 456)
expectTrue(a.refCount > 0)
expectTrue(a.refCount < 10) // optimizations would affect the exact number
}
LifetimeMethodsTestSuite.test("overridden virtual retain/release") {
let a = DerivedVirtualRetainRelease(456)
expectEqual(a.value, 456)
expectTrue(a.calledDerived)
expectTrue(a.refCount > 0)
expectTrue(a.refCount < 10) // optimizations would affect the exact number
}
LifetimeMethodsTestSuite.test("overridden pure virtual retain/release") {
let a = DerivedPureVirtualRetainRelease(789)
expectEqual(a.value, 789)
expectTrue(a.refCount > 0)
expectTrue(a.refCount < 10) // optimizations would affect the exact number
}
runAllTests()