diff --git a/SwiftCompilerSources/Sources/Optimizer/ModulePasses/MandatoryPerformanceOptimizations.swift b/SwiftCompilerSources/Sources/Optimizer/ModulePasses/MandatoryPerformanceOptimizations.swift index 5252a637020..efc1ee76dda 100644 --- a/SwiftCompilerSources/Sources/Optimizer/ModulePasses/MandatoryPerformanceOptimizations.swift +++ b/SwiftCompilerSources/Sources/Optimizer/ModulePasses/MandatoryPerformanceOptimizations.swift @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +import AST import SIL /// Performs mandatory optimizations for performance-annotated functions, and global @@ -39,11 +40,6 @@ let mandatoryPerformanceOptimizations = ModulePass(name: "mandatory-performance- } optimizeFunctionsTopDown(using: &worklist, moduleContext) - - if moduleContext.options.enableEmbeddedSwift { - // Print errors for generic functions in vtables, which is not allowed in embedded Swift. - checkVTablesForGenericFunctions(moduleContext) - } } private func optimizeFunctionsTopDown(using worklist: inout FunctionWorklist, @@ -131,9 +127,27 @@ private func optimize(function: Function, _ context: FunctionPassContext, _ modu case let initExRef as InitExistentialRefInst: if context.options.enableEmbeddedSwift { - specializeWitnessTables(for: initExRef, moduleContext, &worklist) + for c in initExRef.conformances where c.isConcrete { + specializeWitnessTable(for: c, moduleContext) { + worklist.addWitnessMethods(of: $0) + } + } } + case let bi as BuiltinInst: + switch bi.id { + case .BuildOrdinaryTaskExecutorRef, + .BuildOrdinarySerialExecutorRef, + .BuildComplexEqualitySerialExecutorRef: + specializeWitnessTable(for: bi.substitutionMap.conformances[0], moduleContext) { + worklist.addWitnessMethods(of: $0) + } + + default: + break + } + + // We need to de-virtualize deinits of non-copyable types to be able to specialize the deinitializers. case let destroyValue as DestroyValueInst: if !devirtualizeDeinits(of: destroyValue, simplifyCtxt) { @@ -282,47 +296,6 @@ private func shouldInline(apply: FullApplySite, callee: Function, alreadyInlined return false } -private func specializeWitnessTables(for initExRef: InitExistentialRefInst, _ context: ModulePassContext, - _ worklist: inout FunctionWorklist) -{ - for c in initExRef.conformances where c.isConcrete { - let conformance = c.isInherited ? c.inheritedConformance : c - let origWitnessTable = context.lookupWitnessTable(for: conformance) - if conformance.isSpecialized { - if origWitnessTable == nil { - specializeWitnessTable(forConformance: conformance, errorLocation: initExRef.location, context) { - worklist.addWitnessMethods(of: $0) - } - } - } else if let origWitnessTable { - checkForGenericMethods(in: origWitnessTable, errorLocation: initExRef.location, context) - } - } -} - -private func checkForGenericMethods(in witnessTable: WitnessTable, - errorLocation: Location, - _ context: ModulePassContext) -{ - for entry in witnessTable.entries { - if case .method(let requirement, let witness) = entry, - let witness, - witness.isGeneric - { - context.diagnosticEngine.diagnose(.cannot_specialize_witness_method, requirement, at: errorLocation) - return - } - } -} - -private func checkVTablesForGenericFunctions(_ context: ModulePassContext) { - for vTable in context.vTables where !vTable.class.isGenericAtAnyLevel { - for entry in vTable.entries where entry.implementation.isGeneric { - context.diagnosticEngine.diagnose(.non_final_generic_class_function, at: entry.methodDecl.location) - } - } -} - private extension FullApplySite { func resultIsUsedInGlobalInitialization() -> SmallProjectionPath? { guard parentFunction.isGlobalInitOnceFunction, diff --git a/SwiftCompilerSources/Sources/Optimizer/PassManager/ModulePassContext.swift b/SwiftCompilerSources/Sources/Optimizer/PassManager/ModulePassContext.swift index 7a1e9f39354..4f14579a9cc 100644 --- a/SwiftCompilerSources/Sources/Optimizer/PassManager/ModulePassContext.swift +++ b/SwiftCompilerSources/Sources/Optimizer/PassManager/ModulePassContext.swift @@ -144,14 +144,14 @@ struct ModulePassContext : Context, CustomStringConvertible { } @discardableResult - func createWitnessTable(entries: [WitnessTable.Entry], + func createSpecializedWitnessTable(entries: [WitnessTable.Entry], conformance: Conformance, linkage: Linkage, serialized: Bool) -> WitnessTable { let bridgedEntries = entries.map { $0.bridged } let bridgedWitnessTable = bridgedEntries.withBridgedArrayRef { - _bridged.createWitnessTable(linkage.bridged, serialized, conformance.bridged, $0) + _bridged.createSpecializedWitnessTable(linkage.bridged, serialized, conformance.bridged, $0) } return WitnessTable(bridged: bridgedWitnessTable) } diff --git a/SwiftCompilerSources/Sources/Optimizer/Utilities/GenericSpecialization.swift b/SwiftCompilerSources/Sources/Optimizer/Utilities/GenericSpecialization.swift index 929489e5b33..0bd2c9c2f3b 100644 --- a/SwiftCompilerSources/Sources/Optimizer/Utilities/GenericSpecialization.swift +++ b/SwiftCompilerSources/Sources/Optimizer/Utilities/GenericSpecialization.swift @@ -79,7 +79,7 @@ private struct VTableSpecializer { } private func specializeEntries(of vTable: VTable, _ notifyNewFunction: (Function) -> ()) -> [VTable.Entry] { - return vTable.entries.compactMap { entry in + return vTable.entries.map { entry in if !entry.implementation.isGeneric { return entry } @@ -91,8 +91,7 @@ private struct VTableSpecializer { context.loadFunction(function: entry.implementation, loadCalleesRecursively: true), let specializedMethod = context.specialize(function: entry.implementation, for: methodSubs) else { - context.diagnosticEngine.diagnose(.non_final_generic_class_function, at: entry.methodDecl.location) - return nil + return entry } notifyNewFunction(specializedMethod) @@ -106,16 +105,29 @@ private struct VTableSpecializer { } } -func specializeWitnessTable(forConformance conformance: Conformance, - errorLocation: Location, +/// Specializes a witness table of `conformance` for the concrete type of the conformance. +func specializeWitnessTable(for conformance: Conformance, _ context: ModulePassContext, _ notifyNewWitnessTable: (WitnessTable) -> ()) { - let genericConformance = conformance.genericConformance - guard let witnessTable = context.lookupWitnessTable(for: genericConformance) else { + if let existingSpecialization = context.lookupWitnessTable(for: conformance), + existingSpecialization.isSpecialized + { + return + } + + let baseConf = conformance.isInherited ? conformance.inheritedConformance: conformance + if !baseConf.isSpecialized { + var visited = Set() + specializeDefaultMethods(for: conformance, visited: &visited, context, notifyNewWitnessTable) + return + } + + guard let witnessTable = context.lookupWitnessTable(for: baseConf.genericConformance) else { fatalError("no witness table found") } assert(witnessTable.isDefinition, "No witness table available") + let substitutions = baseConf.specializedSubstitutions let newEntries = witnessTable.entries.map { origEntry in switch origEntry { @@ -125,13 +137,14 @@ func specializeWitnessTable(forConformance conformance: Conformance, guard let origMethod = witness else { return origEntry } - let methodSubs = conformance.specializedSubstitutions.getMethodSubstitutions(for: origMethod) + let methodSubs = substitutions.getMethodSubstitutions(for: origMethod, + // Generic self types need to be handled specially (see `getMethodSubstitutions`) + selfType: origMethod.hasGenericSelf(context) ? conformance.type.canonical : nil) guard !methodSubs.conformances.contains(where: {!$0.isValid}), context.loadFunction(function: origMethod, loadCalleesRecursively: true), let specializedMethod = context.specialize(function: origMethod, for: methodSubs) else { - context.diagnosticEngine.diagnose(.cannot_specialize_witness_method, requirement, at: errorLocation) return origEntry } return .method(requirement: requirement, witness: specializedMethod) @@ -139,7 +152,7 @@ func specializeWitnessTable(forConformance conformance: Conformance, let baseConf = context.getSpecializedConformance(of: witness, for: conformance.type, substitutions: conformance.specializedSubstitutions) - specializeWitnessTable(forConformance: baseConf, errorLocation: errorLocation, context, notifyNewWitnessTable) + specializeWitnessTable(for: baseConf, context, notifyNewWitnessTable) return .baseProtocol(requirement: requirement, witness: baseConf) case .associatedType(let requirement, let witness): let substType = witness.subst(with: conformance.specializedSubstitutions) @@ -150,15 +163,104 @@ func specializeWitnessTable(forConformance conformance: Conformance, let concreteAssociateConf = conformance.getAssociatedConformance(ofAssociatedType: requirement.rawType, to: assocConf.protocol) if concreteAssociateConf.isSpecialized { - specializeWitnessTable(forConformance: concreteAssociateConf, - errorLocation: errorLocation, - context, notifyNewWitnessTable) + specializeWitnessTable(for: concreteAssociateConf, context, notifyNewWitnessTable) } return .associatedConformance(requirement: requirement, witness: concreteAssociateConf) } } - let newWT = context.createWitnessTable(entries: newEntries,conformance: conformance, - linkage: .shared, serialized: false) + let newWT = context.createSpecializedWitnessTable(entries: newEntries,conformance: conformance, + linkage: .shared, serialized: false) notifyNewWitnessTable(newWT) } + +/// Specializes the default methods of a non-generic witness table. +/// Default implementations (in protocol extentions) of non-generic protocol methods have a generic +/// self argument. Specialize such methods with the concrete type. Note that it is important to also +/// specialize inherited conformances so that the concrete self type is correct, even for derived classes. +private func specializeDefaultMethods(for conformance: Conformance, + visited: inout Set, + _ context: ModulePassContext, + _ notifyNewWitnessTable: (WitnessTable) -> ()) +{ + // Avoid infinite recursion, which may happen if an associated conformance is the conformance itself. + guard visited.insert(conformance).inserted, + let witnessTable = context.lookupWitnessTable(for: conformance.rootConformance) + else { + return + } + + assert(witnessTable.isDefinition, "No witness table available") + + var specialized = false + + let newEntries = witnessTable.entries.map { origEntry in + switch origEntry { + case .invalid: + return WitnessTable.Entry.invalid + case .method(let requirement, let witness): + guard let origMethod = witness, + // Is it a generic method where only self is generic (= a default witness method)? + origMethod.isGeneric, origMethod.isNonGenericWitnessMethod(context) + else { + return origEntry + } + // Replace the generic self type with the concrete type. + let methodSubs = SubstitutionMap(genericSignature: origMethod.genericSignature, + replacementTypes: [conformance.type]) + + guard !methodSubs.conformances.contains(where: {!$0.isValid}), + context.loadFunction(function: origMethod, loadCalleesRecursively: true), + let specializedMethod = context.specialize(function: origMethod, for: methodSubs) else + { + return origEntry + } + specialized = true + return .method(requirement: requirement, witness: specializedMethod) + case .baseProtocol(_, let witness): + specializeDefaultMethods(for: witness, visited: &visited, context, notifyNewWitnessTable) + return origEntry + case .associatedType: + return origEntry + case .associatedConformance(_, let assocConf): + specializeDefaultMethods(for: assocConf, visited: &visited, context, notifyNewWitnessTable) + return origEntry + } + } + // If the witness table does not contain any default methods, there is no need to create a + // specialized witness table. + if specialized { + let newWT = context.createSpecializedWitnessTable(entries: newEntries,conformance: conformance, + linkage: .shared, serialized: false) + notifyNewWitnessTable(newWT) + } +} + +private extension Function { + // True, if this is a non-generic method which might have a generic self argument. + // Default implementations (in protocol extentions) of non-generic protocol methods have a generic + // self argument. + func isNonGenericWitnessMethod(_ context: some Context) -> Bool { + switch loweredFunctionType.invocationGenericSignatureOfFunction.genericParameters.count { + case 0: + return true + case 1: + return hasGenericSelf(context) + default: + return false + } + } + + // True, if the self argument is a generic parameter. + func hasGenericSelf(_ context: some Context) -> Bool { + let convention = FunctionConvention(for: loweredFunctionType, + hasLoweredAddresses: context.moduleHasLoweredAddresses) + if convention.hasSelfParameter, + let selfParam = convention.parameters.last, + selfParam.type.isGenericTypeParameter + { + return true + } + return false + } +} diff --git a/include/swift/SILOptimizer/OptimizerBridging.h b/include/swift/SILOptimizer/OptimizerBridging.h index 6b65a9601cb..cdfbe3df3b1 100644 --- a/include/swift/SILOptimizer/OptimizerBridging.h +++ b/include/swift/SILOptimizer/OptimizerBridging.h @@ -328,7 +328,7 @@ struct BridgedPassContext { BridgedSubstitutionMap substitutions) const; SWIFT_IMPORT_UNSAFE BRIDGED_INLINE OptionalBridgedWitnessTable lookupWitnessTable(BridgedConformance conformance) const; - SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedWitnessTable createWitnessTable(BridgedLinkage linkage, + SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedWitnessTable createSpecializedWitnessTable(BridgedLinkage linkage, bool serialized, BridgedConformance conformance, BridgedArrayRef bridgedEntries) const; diff --git a/include/swift/SILOptimizer/OptimizerBridgingImpl.h b/include/swift/SILOptimizer/OptimizerBridgingImpl.h index a4e2ab7434c..d9f0068c05d 100644 --- a/include/swift/SILOptimizer/OptimizerBridgingImpl.h +++ b/include/swift/SILOptimizer/OptimizerBridgingImpl.h @@ -445,7 +445,7 @@ OptionalBridgedWitnessTable BridgedPassContext::lookupWitnessTable(BridgedConfor return {mod->lookUpWitnessTable(ref.getConcrete())}; } -BridgedWitnessTable BridgedPassContext::createWitnessTable(BridgedLinkage linkage, +BridgedWitnessTable BridgedPassContext::createSpecializedWitnessTable(BridgedLinkage linkage, bool serialized, BridgedConformance conformance, BridgedArrayRef bridgedEntries) const { diff --git a/test/embedded/existential-default-method.swift b/test/embedded/existential-default-method.swift new file mode 100644 index 00000000000..7d24545002d --- /dev/null +++ b/test/embedded/existential-default-method.swift @@ -0,0 +1,101 @@ +// RUN: %target-run-simple-swift(-enable-experimental-feature Embedded -parse-as-library -wmo) | %FileCheck %s +// RUN: %target-run-simple-swift(-enable-experimental-feature Embedded -parse-as-library -wmo -O) | %FileCheck %s +// RUN: %target-run-simple-swift(-enable-experimental-feature Embedded -parse-as-library -wmo -Osize) | %FileCheck %s + +// REQUIRES: executable_test +// REQUIRES: optimized_stdlib +// REQUIRES: swift_feature_Embedded + +// Simple case + +public protocol ProtocolWithDefaultMethod: AnyObject { + func getInt() -> Int +} + +extension ProtocolWithDefaultMethod { + public func getInt() -> Int { + return 42 + } +} + +public class Class: ProtocolWithDefaultMethod { +} + +public class GenClass: ProtocolWithDefaultMethod { +} + +func test(existential: any ProtocolWithDefaultMethod) { + print(existential.getInt()) +} + +// Test that we specialize for the correct derived class + +class C { + class func g() -> Int { + return 1 + } +} + +class D: C { + override class func g() -> Int { + return 2 + } +} + +protocol P: AnyObject { + static func g() -> Int + func test() -> Int +} + +extension P { + func test() -> Int { + Self.g() + } +} + +extension C: P {} + +func createDerived() -> P { + return D() +} + +// Test that we don't end up in an infinite recursion loop + +public protocol RecursiveProto: AnyObject { + associatedtype T: RecursiveProto + func getInt() -> Int + func getT() -> T +} + +extension RecursiveProto { + public func getInt() -> Int { + return 27 + } +} + +public class RecursiveClass: RecursiveProto { + public typealias T = RecursiveClass + public func getT() -> RecursiveClass { + return self + } +} + +func testRecursive(existential: any RecursiveProto) { + print(existential.getT().getInt()) +} + + +@main +struct Main { + static func main() { + // CHECK: 42 + test(existential: Class()) + // CHECK: 42 + test(existential: GenClass()) + // CHECK: 2 + print(createDerived().test()) + // CHECK: 27 + testRecursive(existential: RecursiveClass()) + } +} +