MandatoryPerformanceOptimizations: support default methods for class existentials

For example:
```
protocol P: AnyObject {
  func foo()
}
extension P {
  func foo() {}
}
class C: P {}

let e: any P = C()
```

Such default methods are SILGen'd with a generic self argument. Therefore we need to specialize such witness methods, even if the conforming type is not generic.

rdar://145855851
This commit is contained in:
Erik Eckstein
2025-04-17 16:06:10 +02:00
parent 7a8a50a2b3
commit d222cf20f1
6 changed files with 242 additions and 66 deletions

View File

@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
import AST
import SIL
/// Performs mandatory optimizations for performance-annotated functions, and global
@@ -39,11 +40,6 @@ let mandatoryPerformanceOptimizations = ModulePass(name: "mandatory-performance-
}
optimizeFunctionsTopDown(using: &worklist, moduleContext)
if moduleContext.options.enableEmbeddedSwift {
// Print errors for generic functions in vtables, which is not allowed in embedded Swift.
checkVTablesForGenericFunctions(moduleContext)
}
}
private func optimizeFunctionsTopDown(using worklist: inout FunctionWorklist,
@@ -131,9 +127,27 @@ private func optimize(function: Function, _ context: FunctionPassContext, _ modu
case let initExRef as InitExistentialRefInst:
if context.options.enableEmbeddedSwift {
specializeWitnessTables(for: initExRef, moduleContext, &worklist)
for c in initExRef.conformances where c.isConcrete {
specializeWitnessTable(for: c, moduleContext) {
worklist.addWitnessMethods(of: $0)
}
}
}
case let bi as BuiltinInst:
switch bi.id {
case .BuildOrdinaryTaskExecutorRef,
.BuildOrdinarySerialExecutorRef,
.BuildComplexEqualitySerialExecutorRef:
specializeWitnessTable(for: bi.substitutionMap.conformances[0], moduleContext) {
worklist.addWitnessMethods(of: $0)
}
default:
break
}
// We need to de-virtualize deinits of non-copyable types to be able to specialize the deinitializers.
case let destroyValue as DestroyValueInst:
if !devirtualizeDeinits(of: destroyValue, simplifyCtxt) {
@@ -282,47 +296,6 @@ private func shouldInline(apply: FullApplySite, callee: Function, alreadyInlined
return false
}
private func specializeWitnessTables(for initExRef: InitExistentialRefInst, _ context: ModulePassContext,
_ worklist: inout FunctionWorklist)
{
for c in initExRef.conformances where c.isConcrete {
let conformance = c.isInherited ? c.inheritedConformance : c
let origWitnessTable = context.lookupWitnessTable(for: conformance)
if conformance.isSpecialized {
if origWitnessTable == nil {
specializeWitnessTable(forConformance: conformance, errorLocation: initExRef.location, context) {
worklist.addWitnessMethods(of: $0)
}
}
} else if let origWitnessTable {
checkForGenericMethods(in: origWitnessTable, errorLocation: initExRef.location, context)
}
}
}
private func checkForGenericMethods(in witnessTable: WitnessTable,
errorLocation: Location,
_ context: ModulePassContext)
{
for entry in witnessTable.entries {
if case .method(let requirement, let witness) = entry,
let witness,
witness.isGeneric
{
context.diagnosticEngine.diagnose(.cannot_specialize_witness_method, requirement, at: errorLocation)
return
}
}
}
private func checkVTablesForGenericFunctions(_ context: ModulePassContext) {
for vTable in context.vTables where !vTable.class.isGenericAtAnyLevel {
for entry in vTable.entries where entry.implementation.isGeneric {
context.diagnosticEngine.diagnose(.non_final_generic_class_function, at: entry.methodDecl.location)
}
}
}
private extension FullApplySite {
func resultIsUsedInGlobalInitialization() -> SmallProjectionPath? {
guard parentFunction.isGlobalInitOnceFunction,

View File

@@ -144,14 +144,14 @@ struct ModulePassContext : Context, CustomStringConvertible {
}
@discardableResult
func createWitnessTable(entries: [WitnessTable.Entry],
func createSpecializedWitnessTable(entries: [WitnessTable.Entry],
conformance: Conformance,
linkage: Linkage,
serialized: Bool) -> WitnessTable
{
let bridgedEntries = entries.map { $0.bridged }
let bridgedWitnessTable = bridgedEntries.withBridgedArrayRef {
_bridged.createWitnessTable(linkage.bridged, serialized, conformance.bridged, $0)
_bridged.createSpecializedWitnessTable(linkage.bridged, serialized, conformance.bridged, $0)
}
return WitnessTable(bridged: bridgedWitnessTable)
}

View File

@@ -79,7 +79,7 @@ private struct VTableSpecializer {
}
private func specializeEntries(of vTable: VTable, _ notifyNewFunction: (Function) -> ()) -> [VTable.Entry] {
return vTable.entries.compactMap { entry in
return vTable.entries.map { entry in
if !entry.implementation.isGeneric {
return entry
}
@@ -91,8 +91,7 @@ private struct VTableSpecializer {
context.loadFunction(function: entry.implementation, loadCalleesRecursively: true),
let specializedMethod = context.specialize(function: entry.implementation, for: methodSubs) else
{
context.diagnosticEngine.diagnose(.non_final_generic_class_function, at: entry.methodDecl.location)
return nil
return entry
}
notifyNewFunction(specializedMethod)
@@ -106,16 +105,29 @@ private struct VTableSpecializer {
}
}
func specializeWitnessTable(forConformance conformance: Conformance,
errorLocation: Location,
/// Specializes a witness table of `conformance` for the concrete type of the conformance.
func specializeWitnessTable(for conformance: Conformance,
_ context: ModulePassContext,
_ notifyNewWitnessTable: (WitnessTable) -> ())
{
let genericConformance = conformance.genericConformance
guard let witnessTable = context.lookupWitnessTable(for: genericConformance) else {
if let existingSpecialization = context.lookupWitnessTable(for: conformance),
existingSpecialization.isSpecialized
{
return
}
let baseConf = conformance.isInherited ? conformance.inheritedConformance: conformance
if !baseConf.isSpecialized {
var visited = Set<Conformance>()
specializeDefaultMethods(for: conformance, visited: &visited, context, notifyNewWitnessTable)
return
}
guard let witnessTable = context.lookupWitnessTable(for: baseConf.genericConformance) else {
fatalError("no witness table found")
}
assert(witnessTable.isDefinition, "No witness table available")
let substitutions = baseConf.specializedSubstitutions
let newEntries = witnessTable.entries.map { origEntry in
switch origEntry {
@@ -125,13 +137,14 @@ func specializeWitnessTable(forConformance conformance: Conformance,
guard let origMethod = witness else {
return origEntry
}
let methodSubs = conformance.specializedSubstitutions.getMethodSubstitutions(for: origMethod)
let methodSubs = substitutions.getMethodSubstitutions(for: origMethod,
// Generic self types need to be handled specially (see `getMethodSubstitutions`)
selfType: origMethod.hasGenericSelf(context) ? conformance.type.canonical : nil)
guard !methodSubs.conformances.contains(where: {!$0.isValid}),
context.loadFunction(function: origMethod, loadCalleesRecursively: true),
let specializedMethod = context.specialize(function: origMethod, for: methodSubs) else
{
context.diagnosticEngine.diagnose(.cannot_specialize_witness_method, requirement, at: errorLocation)
return origEntry
}
return .method(requirement: requirement, witness: specializedMethod)
@@ -139,7 +152,7 @@ func specializeWitnessTable(forConformance conformance: Conformance,
let baseConf = context.getSpecializedConformance(of: witness,
for: conformance.type,
substitutions: conformance.specializedSubstitutions)
specializeWitnessTable(forConformance: baseConf, errorLocation: errorLocation, context, notifyNewWitnessTable)
specializeWitnessTable(for: baseConf, context, notifyNewWitnessTable)
return .baseProtocol(requirement: requirement, witness: baseConf)
case .associatedType(let requirement, let witness):
let substType = witness.subst(with: conformance.specializedSubstitutions)
@@ -150,15 +163,104 @@ func specializeWitnessTable(forConformance conformance: Conformance,
let concreteAssociateConf = conformance.getAssociatedConformance(ofAssociatedType: requirement.rawType,
to: assocConf.protocol)
if concreteAssociateConf.isSpecialized {
specializeWitnessTable(forConformance: concreteAssociateConf,
errorLocation: errorLocation,
context, notifyNewWitnessTable)
specializeWitnessTable(for: concreteAssociateConf, context, notifyNewWitnessTable)
}
return .associatedConformance(requirement: requirement,
witness: concreteAssociateConf)
}
}
let newWT = context.createWitnessTable(entries: newEntries,conformance: conformance,
linkage: .shared, serialized: false)
let newWT = context.createSpecializedWitnessTable(entries: newEntries,conformance: conformance,
linkage: .shared, serialized: false)
notifyNewWitnessTable(newWT)
}
/// Specializes the default methods of a non-generic witness table.
/// Default implementations (in protocol extentions) of non-generic protocol methods have a generic
/// self argument. Specialize such methods with the concrete type. Note that it is important to also
/// specialize inherited conformances so that the concrete self type is correct, even for derived classes.
private func specializeDefaultMethods(for conformance: Conformance,
visited: inout Set<Conformance>,
_ context: ModulePassContext,
_ notifyNewWitnessTable: (WitnessTable) -> ())
{
// Avoid infinite recursion, which may happen if an associated conformance is the conformance itself.
guard visited.insert(conformance).inserted,
let witnessTable = context.lookupWitnessTable(for: conformance.rootConformance)
else {
return
}
assert(witnessTable.isDefinition, "No witness table available")
var specialized = false
let newEntries = witnessTable.entries.map { origEntry in
switch origEntry {
case .invalid:
return WitnessTable.Entry.invalid
case .method(let requirement, let witness):
guard let origMethod = witness,
// Is it a generic method where only self is generic (= a default witness method)?
origMethod.isGeneric, origMethod.isNonGenericWitnessMethod(context)
else {
return origEntry
}
// Replace the generic self type with the concrete type.
let methodSubs = SubstitutionMap(genericSignature: origMethod.genericSignature,
replacementTypes: [conformance.type])
guard !methodSubs.conformances.contains(where: {!$0.isValid}),
context.loadFunction(function: origMethod, loadCalleesRecursively: true),
let specializedMethod = context.specialize(function: origMethod, for: methodSubs) else
{
return origEntry
}
specialized = true
return .method(requirement: requirement, witness: specializedMethod)
case .baseProtocol(_, let witness):
specializeDefaultMethods(for: witness, visited: &visited, context, notifyNewWitnessTable)
return origEntry
case .associatedType:
return origEntry
case .associatedConformance(_, let assocConf):
specializeDefaultMethods(for: assocConf, visited: &visited, context, notifyNewWitnessTable)
return origEntry
}
}
// If the witness table does not contain any default methods, there is no need to create a
// specialized witness table.
if specialized {
let newWT = context.createSpecializedWitnessTable(entries: newEntries,conformance: conformance,
linkage: .shared, serialized: false)
notifyNewWitnessTable(newWT)
}
}
private extension Function {
// True, if this is a non-generic method which might have a generic self argument.
// Default implementations (in protocol extentions) of non-generic protocol methods have a generic
// self argument.
func isNonGenericWitnessMethod(_ context: some Context) -> Bool {
switch loweredFunctionType.invocationGenericSignatureOfFunction.genericParameters.count {
case 0:
return true
case 1:
return hasGenericSelf(context)
default:
return false
}
}
// True, if the self argument is a generic parameter.
func hasGenericSelf(_ context: some Context) -> Bool {
let convention = FunctionConvention(for: loweredFunctionType,
hasLoweredAddresses: context.moduleHasLoweredAddresses)
if convention.hasSelfParameter,
let selfParam = convention.parameters.last,
selfParam.type.isGenericTypeParameter
{
return true
}
return false
}
}

View File

@@ -328,7 +328,7 @@ struct BridgedPassContext {
BridgedSubstitutionMap substitutions) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE
OptionalBridgedWitnessTable lookupWitnessTable(BridgedConformance conformance) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedWitnessTable createWitnessTable(BridgedLinkage linkage,
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedWitnessTable createSpecializedWitnessTable(BridgedLinkage linkage,
bool serialized,
BridgedConformance conformance,
BridgedArrayRef bridgedEntries) const;

View File

@@ -445,7 +445,7 @@ OptionalBridgedWitnessTable BridgedPassContext::lookupWitnessTable(BridgedConfor
return {mod->lookUpWitnessTable(ref.getConcrete())};
}
BridgedWitnessTable BridgedPassContext::createWitnessTable(BridgedLinkage linkage,
BridgedWitnessTable BridgedPassContext::createSpecializedWitnessTable(BridgedLinkage linkage,
bool serialized,
BridgedConformance conformance,
BridgedArrayRef bridgedEntries) const {

View File

@@ -0,0 +1,101 @@
// RUN: %target-run-simple-swift(-enable-experimental-feature Embedded -parse-as-library -wmo) | %FileCheck %s
// RUN: %target-run-simple-swift(-enable-experimental-feature Embedded -parse-as-library -wmo -O) | %FileCheck %s
// RUN: %target-run-simple-swift(-enable-experimental-feature Embedded -parse-as-library -wmo -Osize) | %FileCheck %s
// REQUIRES: executable_test
// REQUIRES: optimized_stdlib
// REQUIRES: swift_feature_Embedded
// Simple case
public protocol ProtocolWithDefaultMethod: AnyObject {
func getInt() -> Int
}
extension ProtocolWithDefaultMethod {
public func getInt() -> Int {
return 42
}
}
public class Class: ProtocolWithDefaultMethod {
}
public class GenClass<T>: ProtocolWithDefaultMethod {
}
func test(existential: any ProtocolWithDefaultMethod) {
print(existential.getInt())
}
// Test that we specialize for the correct derived class
class C {
class func g() -> Int {
return 1
}
}
class D: C {
override class func g() -> Int {
return 2
}
}
protocol P: AnyObject {
static func g() -> Int
func test() -> Int
}
extension P {
func test() -> Int {
Self.g()
}
}
extension C: P {}
func createDerived() -> P {
return D()
}
// Test that we don't end up in an infinite recursion loop
public protocol RecursiveProto: AnyObject {
associatedtype T: RecursiveProto
func getInt() -> Int
func getT() -> T
}
extension RecursiveProto {
public func getInt() -> Int {
return 27
}
}
public class RecursiveClass: RecursiveProto {
public typealias T = RecursiveClass
public func getT() -> RecursiveClass {
return self
}
}
func testRecursive(existential: any RecursiveProto) {
print(existential.getT().getInt())
}
@main
struct Main {
static func main() {
// CHECK: 42
test(existential: Class())
// CHECK: 42
test(existential: GenClass<Int>())
// CHECK: 2
print(createDerived().test())
// CHECK: 27
testRecursive(existential: RecursiveClass())
}
}