[Autodiff] Adds logic to generate specialized functions in the closure-spec pass

This commit is contained in:
Kshitij
2024-04-03 16:00:27 -07:00
parent 15cab3a19f
commit ab751d57ab
21 changed files with 1279 additions and 268 deletions

View File

@@ -122,20 +122,6 @@ let autodiffClosureSpecialization = FunctionPass(name: "autodiff-closure-special
if !function.isAutodiffVJP {
return
}
print("Specializing closures in function: \(function.name)")
print("===============================================")
var callSites = gatherCallSites(in: function, context)
callSites.forEach { callSite in
print("PartialApply call site: \(callSite.applySite)")
print("Passed in closures: ")
for index in callSite.closureArgDescriptors.indices {
var closureArgDescriptor = callSite.closureArgDescriptors[index]
print("\(index+1). \(closureArgDescriptor.closureInfo.closure)")
}
}
print("\n")
}
// =========== Top-level functions ========== //
@@ -188,6 +174,45 @@ private func gatherCallSites(in caller: Function, _ context: FunctionPassContext
return callSiteMap.callSites
}
private func getOrCreateSpecializedFunction(basedOn callSite: CallSite, _ context: FunctionPassContext)
-> (function: Function, alreadyExists: Bool)
{
let specializedFunctionName = callSite.specializedCalleeName(context)
if let specializedFunction = context.lookupFunction(name: specializedFunctionName) {
return (specializedFunction, true)
}
let applySiteCallee = callSite.applyCallee
let specializedParameters = applySiteCallee.convention.getSpecializedParameters(basedOn: callSite)
let createFn = { (functionPassContext: FunctionPassContext) in
specializedFunctionName._withBridgedStringRef { nameRef in
let bridgedParamInfos = specializedParameters.map { $0._bridged }
return bridgedParamInfos.withUnsafeBufferPointer { paramBuf in
functionPassContext
._bridged
.ClosureSpecializer_createEmptyFunctionWithSpecializedSignature(nameRef, paramBuf.baseAddress, paramBuf.count,
applySiteCallee.bridged,
applySiteCallee.isSerialized)
.function
}
}
}
let buildFn = { (emptySpecializedFunction, functionPassContext) in
let closureSpecCloner = SpecializationCloner(emptySpecializedFunction: emptySpecializedFunction, functionPassContext)
closureSpecCloner.cloneAndSpecializeFunctionBody(using: callSite)
}
let specializedFunction = context.createAndBuildSpecializedFunction(createFn: createFn, buildFn: buildFn)
return (specializedFunction, false)
}
private func rewriteApplyInstruction(in caller: Function, _ context: FunctionPassContext) {
fatalError("Not implemented")
}
// ===================== Utility functions and extensions ===================== //
private func updateCallSites(for rootClosure: SingleValueInstruction, in callSiteMap: inout CallSiteMap,
@@ -304,9 +329,9 @@ private func handleNonApplies(for rootClosure: SingleValueInstruction,
case let pai as PartialApplyInst:
if !pai.isPullbackInResultOfAutodiffVJP,
pai.isPartialApplyOfReabstractionThunk,
pai.isSupportedClosure,
pai.arguments[0].type.isNoEscapeFunction,
pai.isPartialApplyOfThunk,
// Argument must be a closure
pai.arguments[0].type.isThickFunction
{
rootClosureConversionsAndReabstractions.pushIfNotVisited(contentsOf: pai.uses)
@@ -410,29 +435,18 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap:
continue
}
// We currently only support copying intermediate reabstraction closures if the final closure is ultimately passed
// trivially.
let closureType = use.value.type
let isClosurePassedTrivially = closureType.isNoEscapeFunction && closureType.isThickFunction
// Mark the converted/reabstracted closures as used.
if haveUsedReabstraction {
markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, convertedAndReabstractedClosure: use.value,
convertedAndReabstractedClosures: &convertedAndReabstractedClosures)
if !isClosurePassedTrivially {
continue
}
}
let onlyHaveThinToThickClosure = rootClosure is ThinToThickFunctionInst && !haveUsedReabstraction
guard let closureParamInfo = pai.operandConventions[parameter: use.index] else {
fatalError("While handling apply uses, parameter info not found for operand: \(use)!")
}
if (closureParamInfo.convention.isGuaranteed || isClosurePassedTrivially)
&& !onlyHaveThinToThickClosure
// If we are going to need to release the copied over closure, we must make sure that we understand all the exit
// blocks, i.e., they terminate with an instruction that clearly indicates whether to release the copied over
// closure or leak it.
if closureParamInfo.convention.isGuaranteed,
!onlyHaveThinToThickClosure,
!callee.blocks.allSatisfy({ $0.isReachableExitBlock || $0.terminator is UnreachableInst })
{
continue
}
@@ -465,6 +479,12 @@ private func handleApplies(for rootClosure: SingleValueInstruction, callSiteMap:
continue
}
// Mark the converted/reabstracted closures as used.
if haveUsedReabstraction {
markConvertedAndReabstractedClosuresAsUsed(rootClosure: rootClosure, convertedAndReabstractedClosure: use.value,
convertedAndReabstractedClosures: &convertedAndReabstractedClosures)
}
if callSiteMap[pai] == nil {
callSiteMap.insert(key: pai, value: CallSite(applySite: pai))
}
@@ -563,6 +583,349 @@ private func markConvertedAndReabstractedClosuresAsUsed(rootClosure: Value, conv
}
}
private extension SpecializationCloner {
func cloneAndSpecializeFunctionBody(using callSite: CallSite) {
self.cloneEntryBlockArgsWithoutOrigClosures(usingOrigCalleeAt: callSite)
let (allSpecializedEntryBlockArgs, closureArgIndexToAllClonedReleasableClosures) = cloneAllClosures(at: callSite)
self.cloneFunctionBody(from: callSite.applyCallee, entryBlockArgs: allSpecializedEntryBlockArgs)
self.insertCleanupCodeForClonedReleasableClosures(
from: callSite, closureArgIndexToAllClonedReleasableClosures: closureArgIndexToAllClonedReleasableClosures)
}
private func cloneEntryBlockArgsWithoutOrigClosures(usingOrigCalleeAt callSite: CallSite) {
let originalEntryBlock = callSite.applyCallee.entryBlock
let clonedFunction = self.cloned
let clonedEntryBlock = self.entryBlock
originalEntryBlock.arguments
.enumerated()
.filter { index, _ in !callSite.hasClosureArg(at: index) }
.forEach { _, arg in
let clonedEntryBlockArgType = arg.type.getLoweredType(in: clonedFunction)
let clonedEntryBlockArg = clonedEntryBlock.addFunctionArgument(type: clonedEntryBlockArgType, self.context)
clonedEntryBlockArg.copyFlags(from: arg as! FunctionArgument)
}
}
/// Clones all closures, originally passed to the callee at the given callSite, into the specialized function.
///
/// Returns the following -
/// - allSpecializedEntryBlockArgs: Complete list of entry block arguments for the specialized function. This includes
/// the original arguments to the function (minus the closure arguments) and the arguments representing the values
/// originally captured by the skipped closure arguments.
///
/// - closureArgIndexToAllClonedReleasableClosures: Mapping from a closure's argument index at `callSite` to the list
/// of corresponding releasable closures cloned into the specialized function. We have a "list" because we clone
/// "closure chains", which consist of a "root" closure and its conversions/reabstractions. This map is used to
/// generate cleanup code for the cloned closures in the specialized function.
private func cloneAllClosures(at callSite: CallSite)
-> (allSpecializedEntryBlockArgs: [Value],
closureArgIndexToAllClonedReleasableClosures: [Int: [SingleValueInstruction]])
{
func entryBlockArgsWithOrigClosuresSkipped() -> [Value?] {
var clonedNonClosureEntryBlockArgs = self.entryBlock.arguments.makeIterator()
return callSite.applyCallee
.entryBlock
.arguments
.enumerated()
.reduce(into: []) { result, origArgTuple in
let (index, _) = origArgTuple
if !callSite.hasClosureArg(at: index) {
result.append(clonedNonClosureEntryBlockArgs.next())
} else {
result.append(Optional.none)
}
}
}
var entryBlockArgs: [Value?] = entryBlockArgsWithOrigClosuresSkipped()
var closureArgIndexToAllClonedReleasableClosures: [Int: [SingleValueInstruction]] = [:]
for closureArgDesc in callSite.closureArgDescriptors {
let (finalClonedReabstractedClosure, allClonedReleasableClosures) =
self.cloneClosureChain(representedBy: closureArgDesc, at: callSite)
entryBlockArgs[closureArgDesc.closureArgIndex] = finalClonedReabstractedClosure
closureArgIndexToAllClonedReleasableClosures[closureArgDesc.closureArgIndex] = allClonedReleasableClosures
}
return (entryBlockArgs.map { $0! }, closureArgIndexToAllClonedReleasableClosures)
}
private func cloneClosureChain(representedBy closureArgDesc: ClosureArgDescriptor, at callSite: CallSite)
-> (finalClonedReabstractedClosure: SingleValueInstruction, allClonedReleasableClosures: [SingleValueInstruction])
{
let (origToClonedValueMap, capturedArgRange) = self.addEntryBlockArgs(forValuesCapturedBy: closureArgDesc)
let clonedFunction = self.cloned
let clonedEntryBlock = self.entryBlock
let clonedClosureArgs = Array(clonedEntryBlock.arguments[capturedArgRange])
let builder = clonedEntryBlock.instructions.isEmpty
? Builder(atStartOf: clonedFunction, self.context)
: Builder(atEndOf: clonedEntryBlock, location: clonedEntryBlock.instructions.last!.location, self.context)
let clonedRootClosure = builder.cloneRootClosure(representedBy: closureArgDesc, capturedArgs: clonedClosureArgs)
let (finalClonedReabstractedClosure, releasableClonedReabstractedClosures) =
builder.cloneRootClosureReabstractions(rootClosure: closureArgDesc.closure, clonedRootClosure: clonedRootClosure,
reabstractedClosure: callSite
.appliedArgForClosure(
at: closureArgDesc.closureArgIndex)!,
origToClonedValueMap: origToClonedValueMap,
self.context)
let allClonedReleasableClosures = [clonedRootClosure] + releasableClonedReabstractedClosures
return (finalClonedReabstractedClosure, allClonedReleasableClosures)
}
private func addEntryBlockArgs(forValuesCapturedBy closureArgDesc: ClosureArgDescriptor)
-> (origToClonedValueMap: [HashableValue: Value], capturedArgRange: Range<Int>)
{
var origToClonedValueMap: [HashableValue: Value] = [:]
var capturedArgRange = 0..<0
let clonedFunction = self.cloned
let clonedEntryBlock = self.entryBlock
if let capturedArgs = closureArgDesc.arguments {
let capturedArgRangeStart = clonedEntryBlock.arguments.count
for arg in capturedArgs {
let capturedArg = clonedEntryBlock.addFunctionArgument(type: arg.type.getLoweredType(in: clonedFunction),
self.context)
origToClonedValueMap[arg] = capturedArg
}
let capturedArgRangeEnd = clonedEntryBlock.arguments.count
capturedArgRange = capturedArgRangeStart..<capturedArgRangeEnd
}
return (origToClonedValueMap, capturedArgRange)
}
private func insertCleanupCodeForClonedReleasableClosures(from callSite: CallSite,
closureArgIndexToAllClonedReleasableClosures: [Int: [SingleValueInstruction]])
{
for closureArgDesc in callSite.closureArgDescriptors {
let allClonedReleasableClosures = closureArgIndexToAllClonedReleasableClosures[closureArgDesc.closureArgIndex]!
// Insert a `destroy_value`, for all releasable closures, in all reachable exit BBs if the closure was passed as a
// guaranteed parameter or its type was noescape+thick. This is b/c the closure was passed at +0 originally and we
// need to balance the initial increment of the newly created closure(s).
if closureArgDesc.isClosureGuaranteed || closureArgDesc.isClosureTrivialNoEscape,
!allClonedReleasableClosures.isEmpty
{
for exitBlock in closureArgDesc.reachableExitBBs {
let clonedExitBlock = self.getClonedBlock(for: exitBlock)
let terminator = clonedExitBlock.terminator is UnreachableInst
? clonedExitBlock.terminator.previous!
: clonedExitBlock.terminator
let builder = Builder(before: terminator, self.context)
for closure in allClonedReleasableClosures {
if let pai = closure as? PartialApplyInst,
pai.isOnStack
{
builder.destroyPartialApplyOnStack(paiOnStack: pai)
} else{
builder.createDestroyValue(operand: closure)
}
}
}
}
}
}
}
private extension [HashableValue: Value] {
subscript(key: Value) -> Value? {
get {
self[HashableValue(key)]
}
set {
self[HashableValue(key)] = newValue
}
}
}
private extension Builder {
func cloneRootClosure(representedBy closureArgDesc: ClosureArgDescriptor, capturedArgs: [Value])
-> SingleValueInstruction
{
let function = self.createFunctionRef(closureArgDesc.callee)
if let pai = closureArgDesc.closure as? PartialApplyInst {
return self.createPartialApply(forFunction: function, substitutionMap: SubstitutionMap(),
capturedArgs: capturedArgs, calleeConvention: pai.calleeConvention,
hasUnknownResultIsolation: pai.hasUnknownResultIsolation,
isOnStack: pai.isOnStack)
} else {
return self.createThinToThickFunction(thinFunction: function, resultType: closureArgDesc.closure.type)
}
}
func cloneRootClosureReabstractions(rootClosure: Value, clonedRootClosure: Value, reabstractedClosure: Value,
origToClonedValueMap: [HashableValue: Value], _ context: FunctionPassContext)
-> (finalClonedReabstractedClosure: SingleValueInstruction, releasableClonedReabstractedClosures: [PartialApplyInst])
{
func inner(_ rootClosure: Value, _ clonedRootClosure: Value, _ reabstractedClosure: Value,
_ releasableClonedReabstractedClosures: inout [PartialApplyInst],
_ origToClonedValueMap: inout [HashableValue: Value]) -> Value {
switch reabstractedClosure {
case let reabstractedClosure where reabstractedClosure == rootClosure:
origToClonedValueMap[reabstractedClosure] = clonedRootClosure
return clonedRootClosure
case let cvt as ConvertFunctionInst:
let toBeReabstracted = inner(rootClosure, clonedRootClosure, cvt.fromFunction,
&releasableClonedReabstractedClosures, &origToClonedValueMap)
let reabstracted = self.createConvertFunction(originalFunction: toBeReabstracted, resultType: cvt.type,
withoutActuallyEscaping: cvt.withoutActuallyEscaping)
origToClonedValueMap[cvt] = reabstracted
return reabstracted
case let cvt as ConvertEscapeToNoEscapeInst:
let toBeReabstracted = inner(rootClosure, clonedRootClosure, cvt.fromFunction,
&releasableClonedReabstractedClosures, &origToClonedValueMap)
let reabstracted = self.createConvertEscapeToNoEscape(originalFunction: toBeReabstracted, resultType: cvt.type,
isLifetimeGuaranteed: true)
origToClonedValueMap[cvt] = reabstracted
return reabstracted
case let pai as PartialApplyInst:
let toBeReabstracted = inner(rootClosure, clonedRootClosure, pai.arguments[0],
&releasableClonedReabstractedClosures, &origToClonedValueMap)
guard let function = pai.referencedFunction else {
fatalError("Encountered unsupported reabstraction (via partial_apply) of root closure!")
}
let fri = self.createFunctionRef(function)
let reabstracted = self.createPartialApply(forFunction: fri, substitutionMap: SubstitutionMap(),
capturedArgs: [toBeReabstracted],
calleeConvention: pai.calleeConvention,
hasUnknownResultIsolation: pai.hasUnknownResultIsolation,
isOnStack: pai.isOnStack)
releasableClonedReabstractedClosures.append(reabstracted)
origToClonedValueMap[pai] = reabstracted
return reabstracted
case let mdi as MarkDependenceInst:
let toBeReabstracted = inner(rootClosure, clonedRootClosure, mdi.value, &releasableClonedReabstractedClosures,
&origToClonedValueMap)
let base = origToClonedValueMap[mdi.base]!
let reabstracted = self.createMarkDependence(value: toBeReabstracted, base: base, kind: .Escaping)
origToClonedValueMap[mdi] = reabstracted
return reabstracted
default:
fatalError("Encountered unsupported reabstraction of root closure: \(reabstractedClosure)")
}
}
var releasableClonedReabstractedClosures: [PartialApplyInst] = []
var origToClonedValueMap = origToClonedValueMap
let finalClonedReabstractedClosure = inner(rootClosure, clonedRootClosure, reabstractedClosure,
&releasableClonedReabstractedClosures, &origToClonedValueMap)
return (finalClonedReabstractedClosure as! SingleValueInstruction, releasableClonedReabstractedClosures)
}
func destroyPartialApplyOnStack(paiOnStack: PartialApplyInst) {
precondition(paiOnStack.isOnStack, "Function must only be called for `partial_apply`s on stack!")
for arg in paiOnStack.arguments {
self.createDestroyValue(operand: arg)
}
self.createDestroyValue(operand: paiOnStack)
}
}
private extension FunctionConvention {
func getSpecializedParameters(basedOn callSite: CallSite) -> [ParameterInfo] {
let applySiteCallee = callSite.applyCallee
var specializedParamInfoList: [ParameterInfo] = []
// Start by adding all original parameters except for the closure parameters.
let firstParamIndex = applySiteCallee.argumentConventions.firstParameterIndex
for (index, paramInfo) in applySiteCallee.convention.parameters.enumerated() {
let argIndex = index + firstParamIndex
if !callSite.hasClosureArg(at: argIndex) {
specializedParamInfoList.append(paramInfo)
}
}
// Now, append parameters captured by each of the original closure parameter.
//
// Captured parameters are always appended to the function signature. If the argument type of the captured
// parameter in the callee is:
// - direct and trivial, pass the new parameter as Direct_Unowned.
// - direct and non-trivial, pass the new parameter as Direct_Owned.
// - indirect, pass the new parameter using the same parameter convention as in
// the original closure.
for closureArgDesc in callSite.closureArgDescriptors {
if let closure = closureArgDesc.closure as? PartialApplyInst {
let closureCallee = closureArgDesc.callee
let closureCalleeConvention = closureCallee.convention
let unappliedArgumentCount = closure.unappliedArgumentCount - closureCalleeConvention.indirectSILResultCount
let prevCapturedParameters =
closureCalleeConvention
.parameters[unappliedArgumentCount...]
.enumerated()
.map { index, paramInfo in
let argIndexOfParam = closureCallee.argumentConventions.firstParameterIndex + unappliedArgumentCount + index
let argType = closureCallee.argumentTypes[argIndexOfParam]
return paramInfo.withSpecializedConvention(isArgTypeTrivial: argType.isTrivial(in: closureCallee))
}
specializedParamInfoList.append(contentsOf: prevCapturedParameters)
}
}
return specializedParamInfoList
}
}
private extension ParameterInfo {
func withSpecializedConvention(isArgTypeTrivial: Bool) -> Self {
let specializedParamConvention =
if self.hasAllowedIndirectConvForClosureSpec {
self.convention
} else {
isArgTypeTrivial ? ArgumentConvention.directUnowned : ArgumentConvention.directOwned
}
return ParameterInfo(type: self.type, convention: specializedParamConvention, options: self.options,
hasLoweredAddresses: self.hasLoweredAddresses)
}
var hasAllowedIndirectConvForClosureSpec: Bool {
switch convention {
case .indirectInout, .indirectInoutAliasable:
return true
default:
return false
}
}
}
private extension ArgumentConvention {
var isAllowedIndirectConvForClosureSpec: Bool {
switch self {
case .indirectInout, .indirectInoutAliasable:
return true
default:
return false
}
}
}
private extension PartialApplyInst {
/// True, if the closure obtained from this partial_apply is the
/// pullback returned from an autodiff VJP
@@ -579,6 +942,17 @@ private extension PartialApplyInst {
return false
}
var isPartialApplyOfThunk: Bool {
if self.numArguments == 1 || self.numArguments == 2,
let fun = self.referencedFunction,
fun.thunkKind == .reabstractionThunk || fun.thunkKind == .thunk
{
return true
}
return false
}
var hasOnlyInoutIndirectArguments: Bool {
self.argumentOperands
.filter { !$0.value.type.isObject }
@@ -621,6 +995,20 @@ private extension Function {
}
// ===================== Utility Types ===================== //
private enum HashableValue: Hashable {
case Argument(FunctionArgument)
case Instruction(SingleValueInstruction)
init(_ value: Value) {
if let instr = value as? SingleValueInstruction {
self = .Instruction(instr)
} else if let arg = value as? FunctionArgument {
self = .Argument(arg)
} else {
fatalError("Invalid hashable value: \(value)")
}
}
}
private struct OrderedDict<Key: Hashable, Value> {
private var valueIndexDict: [Key: Int] = [:]
@@ -663,7 +1051,6 @@ private extension CallSiteMap {
}
}
/// Represents all the information required to represent a closure in isolation, i.e., outside of a callsite context
/// where the closure may be getting passed as an argument.
///
@@ -685,6 +1072,76 @@ private struct ClosureArgDescriptor {
/// The index of the closure in the callsite's argument list.
let closureArgumentIndex: Int
let parameterInfo: ParameterInfo
var closure: SingleValueInstruction {
closureInfo.closure
}
var isPartialApply: Bool {
closure is PartialApplyInst
}
var isPartialApplyOnStack: Bool {
if let pai = closure as? PartialApplyInst {
return pai.isOnStack
}
return false
}
var callee: Function {
if let pai = closure as? PartialApplyInst {
return pai.referencedFunction!
} else {
return (closure as! ThinToThickFunctionInst).referencedFunction!
}
}
var location: Location {
closure.location
}
var closureArgIndex: Int {
closureArgumentIndex
}
var closureParamInfo: ParameterInfo {
parameterInfo
}
var numArguments: Int {
if let pai = closure as? PartialApplyInst {
return pai.numArguments
} else {
return 0
}
}
var arguments: (some Sequence<Value>)? {
if let pai = closure as? PartialApplyInst {
return pai.arguments
}
return nil as LazyMapSequence<OperandArray, Value>?
}
var isClosureGuaranteed: Bool {
closureParamInfo.convention.isGuaranteed
}
var isClosureConsumed: Bool {
closureParamInfo.convention.isConsumed
}
var isClosureTrivialNoEscape: Bool {
closureParamInfo.type.SILFunctionType_isTrivialNoescape()
}
var parentFunction: Function {
closure.parentFunction
}
var reachableExitBBs: [BasicBlock] {
closure.parentFunction.blocks.filter { $0.isReachableExitBlock }
}
}
/// Represents a callsite containing one or more closure arguments.
@@ -699,6 +1156,54 @@ private struct CallSite {
public mutating func appendClosureArgDescriptor(_ descriptor: ClosureArgDescriptor) {
self.closureArgDescriptors.append(descriptor)
}
var applyCallee: Function {
applySite.referencedFunction!
}
var isCalleeSerialized: Bool {
applyCallee.isSerialized
}
var firstClosureArgDesc: ClosureArgDescriptor? {
closureArgDescriptors.first
}
func hasClosureArg(at index: Int) -> Bool {
closureArgDescriptors.contains { $0.closureArgumentIndex == index }
}
func closureArgDesc(at index: Int) -> ClosureArgDescriptor? {
closureArgDescriptors.first { $0.closureArgumentIndex == index }
}
func closureArg(at index: Int) -> SingleValueInstruction? {
closureArgDesc(at: index)?.closure
}
func appliedArgForClosure(at index: Int) -> Value? {
if let closureArgDesc = closureArgDesc(at: index) {
return applySite.arguments[closureArgDesc.closureArgIndex - applySite.unappliedArgumentCount]
}
return nil
}
func closureCallee(at index: Int) -> Function? {
closureArgDesc(at: index)?.callee
}
func closureLoc(at index: Int) -> Location? {
closureArgDesc(at: index)?.location
}
func specializedCalleeName(_ context: FunctionPassContext) -> String {
let closureArgs = Array(self.closureArgDescriptors.map { $0.closure })
let closureIndices = Array(self.closureArgDescriptors.map { $0.closureArgIndex })
return context.mangle(withClosureArgs: closureArgs, closureArgIndices: closureIndices,
from: applyCallee)
}
}
// ===================== Unit tests ===================== //
@@ -718,3 +1223,15 @@ let gatherCallSitesTest = FunctionTest("closure_specialize_gather_call_sites") {
}
print("\n")
}
let specializedFunctionSignatureAndBodyTest = FunctionTest(
"closure_specialize_specialized_function_signature_and_body") { function, arguments, context in
var callSites = gatherCallSites(in: function, context)
for callSite in callSites {
let (specializedFunction, _) = getOrCreateSpecializedFunction(basedOn: callSite, context)
print("Generated specialized function: \(specializedFunction.name)")
print("\(specializedFunction)\n")
}
}

View File

@@ -353,12 +353,41 @@ struct FunctionPassContext : MutatingContext {
return String(taking: _bridged.mangleOutlinedVariable(function.bridged))
}
func mangle(withClosureArgs closureArgs: [Value], closureArgIndices: [Int], from applySiteCallee: Function) -> String {
closureArgs.withBridgedValues { bridgedClosureArgsRef in
closureArgIndices.withBridgedArrayRef{bridgedClosureArgIndicesRef in
String(taking: _bridged.mangleWithClosureArgs(
bridgedClosureArgsRef,
bridgedClosureArgIndicesRef,
applySiteCallee.bridged
))
}
}
}
func createGlobalVariable(name: String, type: Type, isPrivate: Bool) -> GlobalVariable {
let gv = name._withBridgedStringRef {
_bridged.createGlobalVariable($0, type.bridged, isPrivate)
}
return gv.globalVar
}
/// Utility function that should be used by optimizations that generate new functions or specialized versions of
/// existing functions.
func createAndBuildSpecializedFunction(createFn: (FunctionPassContext) -> Function,
buildFn: (Function, FunctionPassContext) -> ()) -> Function
{
let specializedFunction = createFn(self)
let nestedFunctionPassContext =
FunctionPassContext(_bridged: _bridged.initializeNestedPassContext(specializedFunction.bridged))
defer { _bridged.deinitializedNestedPassContext() }
buildFn(specializedFunction, nestedFunctionPassContext)
return specializedFunction
}
}
struct SimplifyContext : MutatingContext {
@@ -445,6 +474,13 @@ extension Builder {
context.notifyInstructionChanged, context._bridged.asNotificationHandler())
}
/// Creates a builder which inserts instructions into an empty function, using the location of the function itself.
init(atStartOf function: Function, _ context: some MutatingContext) {
context.verifyIsTransforming(function: function)
self.init(insertAt: .atStartOf(function), context.notifyInstructionChanged,
context._bridged.asNotificationHandler())
}
init(staticInitializerOf global: GlobalVariable, _ context: some MutatingContext) {
self.init(insertAt: .staticInitializer(global),
location: Location.artificialUnreachableLocation,
@@ -629,7 +665,6 @@ extension Function {
bridged.setIsPerformanceConstraint(isPerformanceConstraint)
}
func fixStackNesting(_ context: FunctionPassContext) {
context._bridged.fixStackNesting(bridged)
}

View File

@@ -10,6 +10,7 @@ swift_compiler_sources(Optimizer
AddressUtils.swift
BorrowedFromUpdater.swift
BorrowUtils.swift
SpecializationCloner.swift
DiagnosticEngine.swift
Devirtualization.swift
EscapeUtils.swift

View File

@@ -178,6 +178,11 @@ extension Builder {
insertFunc(builder)
}
}
func destroyCapturedArgs(for paiOnStack: PartialApplyInst) {
precondition(paiOnStack.isOnStack, "Function must only be called for `partial_apply`s on stack!")
self.bridged.destroyCapturedArgs(paiOnStack.bridged)
}
}
extension Value {
@@ -398,23 +403,6 @@ extension LoadInst {
}
}
extension PartialApplyInst {
var isPartialApplyOfReabstractionThunk: Bool {
// A partial_apply of a reabstraction thunk either has a single capture
// (a function) or two captures (function and dynamic Self type).
if self.numArguments == 1 || self.numArguments == 2,
let fun = self.referencedFunction,
fun.isReabstractionThunk,
self.arguments[0].type.isFunction,
self.arguments[0].type.isReferenceCounted(in: self.parentFunction) || self.callee.type.isThickFunction
{
return true
}
return false
}
}
extension FunctionPassContext {
/// Returns true if any blocks were removed.
func removeDeadBlocks(in function: Function) -> Bool {

View File

@@ -0,0 +1,57 @@
//===--- SpecializationCloner.swift --------------------------------------------==//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
import OptimizerBridging
import SIL
/// Utility cloner type that can be used by optimizations that generate new functions or specialized versions of
/// existing functions.
struct SpecializationCloner {
private var _context: FunctionPassContext
private var _bridged: BridgedSpecializationCloner
init(emptySpecializedFunction: Function, _ context: FunctionPassContext) {
self._context = context
self._bridged = BridgedSpecializationCloner(emptySpecializedFunction.bridged)
}
public var context: FunctionPassContext {
self._context
}
public var bridged: BridgedSpecializationCloner {
self._bridged
}
public var cloned: Function {
bridged.getCloned().function
}
public var entryBlock: BasicBlock {
if cloned.blocks.isEmpty {
cloned.appendNewBlock(context)
} else {
cloned.entryBlock
}
}
public func getClonedBlock(for originalBlock: BasicBlock) -> BasicBlock {
bridged.getClonedBasicBlock(originalBlock.bridged).block
}
public func cloneFunctionBody(from originalFunction: Function, entryBlockArgs: [Value]) {
entryBlockArgs.withBridgedValues { bridgedEntryBlockArgs in
bridged.cloneFunctionBody(originalFunction.bridged, self.entryBlock.bridged, bridgedEntryBlockArgs)
}
}
}

View File

@@ -165,7 +165,8 @@ public func registerOptimizerTests() {
linearLivenessTest,
parseTestSpecificationTest,
variableIntroducerTest,
gatherCallSitesTest
gatherCallSitesTest,
specializedFunctionSignatureAndBodyTest
)
// Finally register the thunk they all call through.

View File

@@ -69,6 +69,15 @@ final public class FunctionArgument : Argument {
public var resultDependence: LifetimeDependenceConvention? {
parentFunction.argumentConventions[resultDependsOn: index]
}
/// Copies the following flags from `arg`:
/// 1. noImplicitCopy
/// 2. lifetimeAnnotation
/// 3. closureCapture
/// 4. parameterPack
public func copyFlags(from arg: FunctionArgument) {
bridged.copyFlags(arg.bridged)
}
}
public struct Phi {
@@ -445,6 +454,17 @@ public enum ArgumentConvention : CustomStringConvertible {
}
}
public var isConsumed: Bool {
switch self {
case .indirectIn, .directOwned, .packOwned:
return true
case .indirectInGuaranteed, .directGuaranteed, .packGuaranteed,
.indirectInout, .indirectInoutAliasable, .indirectOut,
.packOut, .packInout, .directUnowned:
return false
}
}
public var isExclusiveIndirect: Bool {
switch self {
case .indirectIn,

View File

@@ -122,6 +122,8 @@ public struct InstructionList : CollectionLikeSequence, IteratorProtocol {
public var first: Instruction? { currentInstruction }
public var last: Instruction? { reversed().first }
public func reversed() -> ReverseInstructionList {
if let inst = currentInstruction {
let lastInst = inst.bridged.getLastInstOfParent().instruction

View File

@@ -19,11 +19,12 @@ public struct Builder {
public enum InsertionPoint {
case before(Instruction)
case atEndOf(BasicBlock)
case atStartOf(Function)
case staticInitializer(GlobalVariable)
}
let insertAt: InsertionPoint
let location: Location
let location: Location?
private let notificationHandler: BridgedChangeNotificationHandler
private let notifyNewInstruction: (Instruction) -> ()
@@ -31,13 +32,15 @@ public struct Builder {
switch insertAt {
case .before(let inst):
return BridgedBuilder(insertAt: .beforeInst, insertionObj: inst.bridged.obj,
loc: location.bridged)
loc: location!.bridged)
case .atEndOf(let block):
return BridgedBuilder(insertAt: .endOfBlock, insertionObj: block.bridged.obj,
loc: location.bridged)
loc: location!.bridged)
case .atStartOf(let function):
return BridgedBuilder(insertAt: .startOfFunction, function: function.bridged)
case .staticInitializer(let global):
return BridgedBuilder(insertAt: .intoGlobal, insertionObj: global.bridged.obj,
loc: location.bridged)
loc: location!.bridged)
}
}
@@ -53,9 +56,26 @@ public struct Builder {
return instruction
}
public init(insertAt: InsertionPoint,
_ notifyNewInstruction: @escaping (Instruction) -> (),
_ notificationHandler: BridgedChangeNotificationHandler) {
guard case let .atStartOf(_) = insertAt else {
fatalError("Initializer must only be used to initialize builders that insert at the start of functions.")
}
self.insertAt = insertAt
self.location = nil;
self.notifyNewInstruction = notifyNewInstruction
self.notificationHandler = notificationHandler
}
public init(insertAt: InsertionPoint, location: Location,
_ notifyNewInstruction: @escaping (Instruction) -> (),
_ notificationHandler: BridgedChangeNotificationHandler) {
if case let .atStartOf(_) = insertAt {
fatalError("Initializer must not be used to initialize builders that insert at the start of functions.")
}
self.insertAt = insertAt
self.location = location;
self.notifyNewInstruction = notifyNewInstruction
@@ -147,6 +167,18 @@ public struct Builder {
return notifyNew(endInit.getAs(EndInitLetRefInst.self))
}
@discardableResult
public func createRetainValue(operand: Value) -> RetainValueInst {
let retain = bridged.createRetainValue(operand.bridged)
return notifyNew(retain.getAs(RetainValueInst.self))
}
@discardableResult
public func createReleaseValue(operand: Value) -> ReleaseValueInst {
let release = bridged.createReleaseValue(operand.bridged)
return notifyNew(release.getAs(ReleaseValueInst.self))
}
@discardableResult
public func createStrongRetain(operand: Value) -> StrongRetainInst {
let retain = bridged.createStrongRetain(operand.bridged)
@@ -290,6 +322,20 @@ public struct Builder {
return notifyNew(tttf.getAs(ThinToThickFunctionInst.self))
}
public func createPartialApply(
forFunction function: Value,
substitutionMap: SubstitutionMap,
capturedArgs: [Value],
calleeConvention: ArgumentConvention,
hasUnknownResultIsolation: Bool,
isOnStack: Bool
) -> PartialApplyInst {
return capturedArgs.withBridgedValues { capturedArgsRef in
let pai = bridged.createPartialApply(function.bridged, capturedArgsRef, calleeConvention.bridged, substitutionMap.bridged, hasUnknownResultIsolation, isOnStack)
return notifyNew(pai.getAs(PartialApplyInst.self))
}
}
@discardableResult
public func createSwitchEnum(enum enumVal: Value,
cases: [(Int, BasicBlock)],
@@ -434,4 +480,14 @@ public struct Builder {
let endAccess = bridged.createEndAccess(beginAccess.bridged)
return notifyNew(endAccess.getAs(EndAccessInst.self))
}
public func createConvertFunction(originalFunction: Value, resultType: Type, withoutActuallyEscaping: Bool) -> ConvertFunctionInst {
let convertFunction = bridged.createConvertFunction(originalFunction.bridged, resultType.bridged, withoutActuallyEscaping)
return notifyNew(convertFunction.getAs(ConvertFunctionInst.self))
}
public func createConvertEscapeToNoEscape(originalFunction: Value, resultType: Type, isLifetimeGuaranteed: Bool) -> ConvertEscapeToNoEscapeInst {
let convertFunction = bridged.createConvertEscapeToNoEscape(originalFunction.bridged, resultType.bridged, isLifetimeGuaranteed)
return notifyNew(convertFunction.getAs(ConvertEscapeToNoEscapeInst.self))
}
}

View File

@@ -93,8 +93,6 @@ final public class Function : CustomStringConvertible, HasShortDescription, Hash
public var isAsync: Bool { bridged.isAsync() }
public var isReabstractionThunk: Bool { bridged.isReabstractionThunk() }
/// True if this is a `[global_init]` function.
///
/// Such a function is typically a global addressor which calls the global's
@@ -140,7 +138,7 @@ final public class Function : CustomStringConvertible, HasShortDescription, Hash
case noThunk, thunk, reabstractionThunk, signatureOptimizedThunk
}
var thunkKind: ThunkKind {
public var thunkKind: ThunkKind {
switch bridged.isThunk() {
case .IsNotThunk: return .noThunk
case .IsThunk: return .thunk

View File

@@ -156,6 +156,13 @@ public struct ParameterInfo : CustomStringConvertible {
public let options: UInt8
public let hasLoweredAddresses: Bool
public init(type: BridgedASTType, convention: ArgumentConvention, options: UInt8, hasLoweredAddresses: Bool) {
self.type = type
self.convention = convention
self.options = options
self.hasLoweredAddresses = hasLoweredAddresses
}
/// Is this parameter passed indirectly in SIL? Most formally
/// indirect results can be passed directly in SIL (opaque values
/// mode). This depends on whether the calling function has lowered

View File

@@ -890,6 +890,7 @@ class UnconditionalCheckedCastInst : SingleValueInstruction, UnaryInstruction {
final public
class ConvertFunctionInst : SingleValueInstruction, UnaryInstruction {
public var fromFunction: Value { operand.value }
public var withoutActuallyEscaping: Bool { bridged.ConvertFunctionInst_withoutActuallyEscaping() }
}
final public
@@ -1040,7 +1041,9 @@ final public class PartialApplyInst : SingleValueInstruction, ApplySite {
return arguments.contains { $0.type.containsNoEscapeFunction }
}
public var hasUnknownResultIsolation: Bool { bridged.PartialApplyInst_hasUnknownResultIsolation() }
public var unappliedArgumentCount: Int { bridged.PartialApply_getCalleeArgIndexOfFirstAppliedArg() }
public var calleeConvention: ArgumentConvention { type.bridged.getCalleeConvention().convention }
}
final public class ApplyInst : SingleValueInstruction, FullApplySite {

View File

@@ -98,6 +98,10 @@ public struct Type : CustomStringConvertible, NoReflectionChildren {
public var tupleElements: TupleElementArray { TupleElementArray(type: self) }
public func getLoweredType(in function: Function) -> Type {
function.bridged.getLoweredType(self.bridged).type
}
/// Can only be used if the type is in fact a nominal type (`isNominal` is true).
/// Returns nil if the nominal is a resilient type because in this case the complete list
/// of fields is not known.

View File

@@ -146,22 +146,8 @@ enum class BridgedArgumentConvention {
Pack_Out
};
struct BridgedParameterInfo {
swift::TypeBase * _Nonnull type;
BridgedArgumentConvention convention;
uint8_t options;
BridgedParameterInfo(swift::TypeBase * _Nonnull type, BridgedArgumentConvention convention, uint8_t options) :
type(type), convention(convention), options(options) {}
#ifdef USED_IN_CPP_SOURCE
inline static BridgedArgumentConvention
castToArgumentConvention(swift::ParameterConvention convention) {
return static_cast<BridgedArgumentConvention>(
swift::SILArgumentConvention(convention).Value);
}
swift::ParameterConvention getParameterConvention() const {
static swift::ParameterConvention getParameterConvention(BridgedArgumentConvention convention) {
switch (convention) {
case BridgedArgumentConvention::Indirect_In: return swift::ParameterConvention::Indirect_In;
case BridgedArgumentConvention::Indirect_In_Guaranteed: return swift::ParameterConvention::Indirect_In_Guaranteed;
@@ -177,6 +163,38 @@ struct BridgedParameterInfo {
case BridgedArgumentConvention::Pack_Out: break;
}
llvm_unreachable("invalid parameter convention");
}
static BridgedArgumentConvention getArgumentConvention(swift::ParameterConvention convention) {
switch (convention) {
case swift::ParameterConvention::Indirect_In: return BridgedArgumentConvention::Indirect_In;
case swift::ParameterConvention::Indirect_In_Guaranteed: return BridgedArgumentConvention::Indirect_In_Guaranteed;
case swift::ParameterConvention::Indirect_Inout: return BridgedArgumentConvention::Indirect_Inout;
case swift::ParameterConvention::Indirect_InoutAliasable: return BridgedArgumentConvention::Indirect_InoutAliasable;
case swift::ParameterConvention::Direct_Owned: return BridgedArgumentConvention::Direct_Owned;
case swift::ParameterConvention::Direct_Unowned: return BridgedArgumentConvention::Direct_Unowned;
case swift::ParameterConvention::Direct_Guaranteed: return BridgedArgumentConvention::Direct_Guaranteed;
case swift::ParameterConvention::Pack_Owned: return BridgedArgumentConvention::Pack_Owned;
case swift::ParameterConvention::Pack_Inout: return BridgedArgumentConvention::Pack_Inout;
case swift::ParameterConvention::Pack_Guaranteed: return BridgedArgumentConvention::Pack_Guaranteed;
}
llvm_unreachable("invalid parameter convention");
}
#endif
struct BridgedParameterInfo {
swift::TypeBase * _Nonnull type;
BridgedArgumentConvention convention;
uint8_t options;
BridgedParameterInfo(swift::TypeBase * _Nonnull type, BridgedArgumentConvention convention, uint8_t options) :
type(type), convention(convention), options(options) {}
#ifdef USED_IN_CPP_SOURCE
inline static BridgedArgumentConvention
castToArgumentConvention(swift::ParameterConvention convention) {
return static_cast<BridgedArgumentConvention>(
swift::SILArgumentConvention(convention).Value);
}
BridgedParameterInfo(swift::SILParameterInfo parameterInfo):
@@ -186,7 +204,7 @@ struct BridgedParameterInfo {
{}
swift::SILParameterInfo unbridged() const {
return swift::SILParameterInfo(swift::CanType(type), getParameterConvention(),
return swift::SILParameterInfo(swift::CanType(type), getParameterConvention(convention),
swift::SILParameterInfo::Options(options));
}
#endif
@@ -279,6 +297,8 @@ struct BridgedASTType {
BRIDGED_INLINE bool SILFunctionType_hasSelfParam() const;
BRIDGED_INLINE bool SILFunctionType_isTrivialNoescape() const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE
BridgedYieldInfoArray SILFunctionType_getYields() const;
@@ -377,6 +397,7 @@ struct BridgedType {
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedType
getTupleElementType(SwiftInt idx) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedType getFunctionTypeWithNoEscape(bool withNoEscape) const;
BRIDGED_INLINE BridgedArgumentConvention getCalleeConvention() const;
};
// SIL Bridging
@@ -531,6 +552,22 @@ struct BridgedFunction {
IsSignatureOptimizedThunk
};
enum class Linkage {
Public,
PublicNonABI,
Package,
PackageNonABI,
Hidden,
Shared,
Private,
PublicExternal,
PackageExternal,
HiddenExternal
};
SWIFT_NAME("init(obj:)")
SWIFT_IMPORT_UNSAFE BridgedFunction(SwiftObject obj) : obj(obj) {}
SWIFT_IMPORT_UNSAFE BridgedFunction() {}
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE swift::SILFunction * _Nonnull getFunction() const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedStringRef getName() const;
SWIFT_IMPORT_UNSAFE BridgedOwnedString getDebugDescription() const;
@@ -549,7 +586,6 @@ struct BridgedFunction {
BRIDGED_INLINE bool isAvailableExternally() const;
BRIDGED_INLINE bool isTransparent() const;
BRIDGED_INLINE bool isAsync() const;
BRIDGED_INLINE bool isReabstractionThunk() const;
BRIDGED_INLINE bool isGlobalInitFunction() const;
BRIDGED_INLINE bool isGlobalInitOnceFunction() const;
BRIDGED_INLINE bool isDestructor() const;
@@ -570,6 +606,8 @@ struct BridgedFunction {
BRIDGED_INLINE void setIsPerformanceConstraint(bool isPerfConstraint) const;
BRIDGED_INLINE bool isResilientNominalDecl(BridgedNominalTypeDecl decl) const;
BRIDGED_INLINE BridgedType getLoweredType(BridgedASTType type) const;
BRIDGED_INLINE BridgedType getLoweredType(BridgedType type) const;
BRIDGED_INLINE void setLinkage(Linkage linkage) const;
bool isTrapNoReturn() const;
bool isAutodiffVJP() const;
SwiftInt specializationLevel() const;
@@ -882,6 +920,7 @@ struct BridgedInstruction {
BRIDGED_INLINE SwiftInt ObjectInst_getNumBaseElements() const;
BRIDGED_INLINE SwiftInt PartialApply_getCalleeArgIndexOfFirstAppliedArg() const;
BRIDGED_INLINE bool PartialApplyInst_isOnStack() const;
BRIDGED_INLINE bool PartialApplyInst_hasUnknownResultIsolation() const;
BRIDGED_INLINE bool AllocStackInst_hasDynamicLifetime() const;
BRIDGED_INLINE bool AllocRefInstBase_isObjc() const;
BRIDGED_INLINE bool AllocRefInstBase_canAllocOnStack() const;
@@ -927,6 +966,7 @@ struct BridgedInstruction {
BRIDGED_INLINE SwiftInt ApplySite_getNumArguments() const;
BRIDGED_INLINE bool ApplySite_isCalleeNoReturn() const;
BRIDGED_INLINE SwiftInt FullApplySite_numIndirectResultArguments() const;
BRIDGED_INLINE bool ConvertFunctionInst_withoutActuallyEscaping() const;
// =========================================================================//
// VarDeclInst and DebugVariableInst
@@ -977,6 +1017,7 @@ struct BridgedArgument {
BRIDGED_INLINE bool isReborrow() const;
BRIDGED_INLINE bool hasResultDependsOn() const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedNullableVarDecl getVarDecl() const;
BRIDGED_INLINE void copyFlags(BridgedArgument fromArgument) const;
};
struct OptionalBridgedArgument {
@@ -1117,7 +1158,7 @@ struct OptionalBridgedDefaultWitnessTable {
struct BridgedBuilder{
enum class InsertAt {
beforeInst, endOfBlock, intoGlobal
beforeInst, endOfBlock, startOfFunction, intoGlobal
} insertAt;
SwiftObject insertionObj;
@@ -1132,6 +1173,8 @@ struct BridgedBuilder{
case BridgedBuilder::InsertAt::endOfBlock:
return swift::SILBuilder(BridgedBasicBlock(insertionObj).unbridged(),
loc.getLoc().getScope());
case BridgedBuilder::InsertAt::startOfFunction:
return swift::SILBuilder(BridgedFunction(insertionObj).getFunction()->getEntryBlock());
case BridgedBuilder::InsertAt::intoGlobal:
return swift::SILBuilder(BridgedGlobalVar(insertionObj).getGlobal());
}
@@ -1141,6 +1184,13 @@ struct BridgedBuilder{
}
#endif
SWIFT_NAME("init(insertAt:insertionObj:loc:)")
SWIFT_IMPORT_UNSAFE BridgedBuilder(InsertAt insertAt, SwiftObject insertionObj, BridgedLocation loc):
insertAt(insertAt), insertionObj(insertionObj), loc(loc) {}
SWIFT_NAME("init(insertAt:function:)")
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedBuilder(InsertAt insertAt, BridgedFunction function);
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createBuiltinBinaryFunction(BridgedStringRef name,
BridgedType operandType, BridgedType resultType,
BridgedValueArray arguments) const;
@@ -1165,6 +1215,8 @@ struct BridgedBuilder{
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createBeginDeallocRef(BridgedValue reference,
BridgedValue allocation) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createEndInitLetRef(BridgedValue op) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createRetainValue(BridgedValue op) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createReleaseValue(BridgedValue op) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createStrongRetain(BridgedValue op) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createStrongRelease(BridgedValue op) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createUnownedRetain(BridgedValue op) const;
@@ -1207,6 +1259,12 @@ struct BridgedBuilder{
BridgedType resultType) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createThinToThickFunction(BridgedValue fn,
BridgedType resultType) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createPartialApply(BridgedValue fn,
BridgedValueArray bridgedCapturedArgs,
BridgedArgumentConvention calleeConvention,
BridgedSubstitutionMap bridgedSubstitutionMap = BridgedSubstitutionMap(),
bool hasUnknownIsolation = true,
bool isOnStack = false) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createBranch(BridgedBasicBlock destBlock,
BridgedValueArray arguments) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createUnreachable() const;
@@ -1247,6 +1305,11 @@ struct BridgedBuilder{
BridgedValue value, BridgedValue base, BridgedInstruction::MarkDependenceKind dependenceKind) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createEndAccess(BridgedValue value) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createConvertFunction(BridgedValue originalFunction, BridgedType resultType, bool withoutActuallyEscaping) const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE BridgedInstruction createConvertEscapeToNoEscape(BridgedValue originalFunction, BridgedType resultType, bool isLifetimeGuaranteed) const;
SWIFT_IMPORT_UNSAFE void destroyCapturedArgs(BridgedInstruction partialApply) const;
};
// Passmanager and Context

View File

@@ -153,6 +153,10 @@ bool BridgedASTType::SILFunctionType_hasSelfParam() const {
return unbridged()->castTo<swift::SILFunctionType>()->hasSelfParam();
}
bool BridgedASTType::SILFunctionType_isTrivialNoescape() const {
return unbridged()->castTo<swift::SILFunctionType>()->isTrivialNoEscape();
}
BridgedYieldInfoArray BridgedASTType::SILFunctionType_getYields() const {
return unbridged()->castTo<swift::SILFunctionType>()->getYields();
}
@@ -389,6 +393,11 @@ BridgedType BridgedType::getFunctionTypeWithNoEscape(bool withNoEscape) const {
return swift::SILType::getPrimitiveObjectType(newTy);
}
BridgedArgumentConvention BridgedType::getCalleeConvention() const {
auto fnType = unbridged().getAs<swift::SILFunctionType>();
return getArgumentConvention(fnType->getCalleeConvention());
}
//===----------------------------------------------------------------------===//
// BridgedValue
//===----------------------------------------------------------------------===//
@@ -523,6 +532,11 @@ BridgedNullableVarDecl BridgedArgument::getVarDecl() const {
const_cast<swift::ValueDecl*>(getArgument()->getDecl()))};
}
void BridgedArgument::copyFlags(BridgedArgument fromArgument) const {
auto *fArg = static_cast<swift::SILFunctionArgument *>(getArgument());
fArg->copyFlags(static_cast<swift::SILFunctionArgument *>(fromArgument.getArgument()));
}
//===----------------------------------------------------------------------===//
// BridgedSubstitutionMap
//===----------------------------------------------------------------------===//
@@ -643,10 +657,6 @@ bool BridgedFunction::isAsync() const {
return getFunction()->isAsync();
}
bool BridgedFunction::isReabstractionThunk() const {
return getFunction()->isThunk() == swift::IsReabstractionThunk;
}
bool BridgedFunction::isGlobalInitFunction() const {
return getFunction()->isGlobalInit();
}
@@ -718,6 +728,10 @@ void BridgedFunction::setIsPerformanceConstraint(bool isPerfConstraint) const {
getFunction()->setIsPerformanceConstraint(isPerfConstraint);
}
void BridgedFunction::setLinkage(Linkage linkage) const {
getFunction()->setLinkage((swift::SILLinkage)linkage);
}
bool BridgedFunction::isResilientNominalDecl(BridgedNominalTypeDecl decl) const {
return decl.unbridged()->isResilient(getFunction()->getModule().getSwiftModule(),
getFunction()->getResilienceExpansion());
@@ -727,6 +741,10 @@ BridgedType BridgedFunction::getLoweredType(BridgedASTType type) const {
return BridgedType(getFunction()->getLoweredType(type.type));
}
BridgedType BridgedFunction::getLoweredType(BridgedType type) const {
return BridgedType(getFunction()->getLoweredType(type.unbridged()));
}
//===----------------------------------------------------------------------===//
// BridgedGlobalVar
//===----------------------------------------------------------------------===//
@@ -1097,6 +1115,10 @@ bool BridgedInstruction::PartialApplyInst_isOnStack() const {
return getAs<swift::PartialApplyInst>()->isOnStack();
}
bool BridgedInstruction::PartialApplyInst_hasUnknownResultIsolation() const {
return getAs<swift::PartialApplyInst>()->getResultIsolation() == swift::SILFunctionTypeIsolation::Unknown;
}
bool BridgedInstruction::AllocStackInst_hasDynamicLifetime() const {
return getAs<swift::AllocStackInst>()->hasDynamicLifetime();
}
@@ -1304,6 +1326,10 @@ SwiftInt BridgedInstruction::FullApplySite_numIndirectResultArguments() const {
return fas.getNumIndirectSILResults();
}
bool BridgedInstruction::ConvertFunctionInst_withoutActuallyEscaping() const {
return getAs<swift::ConvertFunctionInst>()->withoutActuallyEscaping();
}
//===----------------------------------------------------------------------===//
// VarDeclInst and DebugVariableInst
//===----------------------------------------------------------------------===//
@@ -1481,6 +1507,10 @@ BridgedWitnessTableEntryArray BridgedDefaultWitnessTable::getEntries() const {
// BridgedBuilder
//===----------------------------------------------------------------------===//
BridgedBuilder::BridgedBuilder(InsertAt insertAt, BridgedFunction function):
insertAt(insertAt), insertionObj(function.obj),
loc(BridgedLocation({function.getFunction()->getLocation(), function.getFunction()->getDebugScope()})) {}
BridgedInstruction BridgedBuilder::createBuiltinBinaryFunction(BridgedStringRef name,
BridgedType operandType, BridgedType resultType,
BridgedValueArray arguments) const {
@@ -1559,6 +1589,16 @@ BridgedInstruction BridgedBuilder::createEndInitLetRef(BridgedValue op) const {
return {unbridged().createEndInitLetRef(regularLoc(), op.getSILValue())};
}
BridgedInstruction BridgedBuilder::createRetainValue(BridgedValue op) const {
auto b = unbridged();
return {b.createRetainValue(regularLoc(), op.getSILValue(), b.getDefaultAtomicity())};
}
BridgedInstruction BridgedBuilder::createReleaseValue(BridgedValue op) const {
auto b = unbridged();
return {b.createReleaseValue(regularLoc(), op.getSILValue(), b.getDefaultAtomicity())};
}
BridgedInstruction BridgedBuilder::createStrongRetain(BridgedValue op) const {
auto b = unbridged();
return {b.createStrongRetain(regularLoc(), op.getSILValue(), b.getDefaultAtomicity())};
@@ -1687,6 +1727,24 @@ BridgedInstruction BridgedBuilder::createThinToThickFunction(BridgedValue fn, Br
resultType.unbridged())};
}
BridgedInstruction BridgedBuilder::createPartialApply(BridgedValue funcRef,
BridgedValueArray bridgedCapturedArgs,
BridgedArgumentConvention calleeConvention,
BridgedSubstitutionMap bridgedSubstitutionMap,
bool hasUnknownIsolation,
bool isOnStack) const {
llvm::SmallVector<swift::SILValue, 8> capturedArgs;
return {unbridged().createPartialApply(
regularLoc(),
funcRef.getSILValue(),
bridgedSubstitutionMap.unbridged(),
bridgedCapturedArgs.getValues(capturedArgs),
getParameterConvention(calleeConvention),
hasUnknownIsolation ? swift::SILFunctionTypeIsolation::Unknown : swift::SILFunctionTypeIsolation::Erased,
isOnStack ? swift:: PartialApplyInst::OnStack : swift::PartialApplyInst::NotOnStack
)};
}
BridgedInstruction BridgedBuilder::createBranch(BridgedBasicBlock destBlock, BridgedValueArray arguments) const {
llvm::SmallVector<swift::SILValue, 16> argValues;
return {unbridged().createBranch(regularLoc(), destBlock.unbridged(),
@@ -1810,6 +1868,14 @@ BridgedInstruction BridgedBuilder::createEndAccess(BridgedValue value) const {
return {unbridged().createEndAccess(regularLoc(), value.getSILValue(), false)};
}
BridgedInstruction BridgedBuilder::createConvertFunction(BridgedValue originalFunction, BridgedType resultType, bool withoutActuallyEscaping) const {
return {unbridged().createConvertFunction(regularLoc(), originalFunction.getSILValue(), resultType.unbridged(), withoutActuallyEscaping)};
}
BridgedInstruction BridgedBuilder::createConvertEscapeToNoEscape(BridgedValue originalFunction, BridgedType resultType, bool isLifetimeGuaranteed) const {
return {unbridged().createConvertEscapeToNoEscape(regularLoc(), originalFunction.getSILValue(), resultType.unbridged(), isLifetimeGuaranteed)};
}
SWIFT_END_NULLABILITY_ANNOTATIONS
#endif

View File

@@ -56,6 +56,7 @@ class SwiftPassInvocation;
class FixedSizeSlabPayload;
class FixedSizeSlab;
class SILVTable;
class ClosureSpecializationCloner;
}
struct BridgedPassContext;
@@ -180,6 +181,15 @@ struct BridgedCloner {
void clone(BridgedInstruction inst);
};
struct BridgedSpecializationCloner {
swift::ClosureSpecializationCloner * _Nonnull closureSpecCloner;
BridgedSpecializationCloner(BridgedFunction emptySpecializedFunction);
BridgedFunction getCloned() const;
BridgedBasicBlock getClonedBasicBlock(BridgedBasicBlock originalBasicBlock) const;
void cloneFunctionBody(BridgedFunction originalFunction, BridgedBasicBlock clonedEntryBlock, BridgedValueArray clonedEntryBlockArgs) const;
};
struct BridgedPassContext {
swift::SwiftPassInvocation * _Nonnull invocation;
@@ -324,6 +334,8 @@ struct BridgedPassContext {
BRIDGED_INLINE void beginTransformFunction(BridgedFunction function) const;
BRIDGED_INLINE void endTransformFunction() const;
BRIDGED_INLINE bool continueWithNextSubpassRun(OptionalBridgedInstruction inst) const;
BRIDGED_INLINE BridgedPassContext initializeNestedPassContext(BridgedFunction newFunction) const;
BRIDGED_INLINE void deinitializedNestedPassContext() const;
// SSAUpdater
@@ -347,6 +359,13 @@ struct BridgedPassContext {
BRIDGED_INLINE bool enableMoveInoutStackProtection() const;
BRIDGED_INLINE AssertConfiguration getAssertConfiguration() const;
bool enableSimplificationFor(BridgedInstruction inst) const;
// Closure specializer
SWIFT_IMPORT_UNSAFE BridgedFunction ClosureSpecializer_createEmptyFunctionWithSpecializedSignature(BridgedStringRef specializedName,
const BridgedParameterInfo * _Nullable specializedBridgedParams,
SwiftInt paramCount,
BridgedFunction bridgedApplySiteCallee,
bool isSerialized) const;
};
bool FullApplySite_canInline(BridgedInstruction apply);

View File

@@ -439,6 +439,14 @@ bool BridgedPassContext::continueWithNextSubpassRun(OptionalBridgedInstruction i
inst.unbridged(), invocation->getFunction(), invocation->getTransform());
}
BridgedPassContext BridgedPassContext::initializeNestedPassContext(BridgedFunction newFunction) const {
return { invocation->initializeNestedSwiftPassInvocation(newFunction.getFunction()) };
}
void BridgedPassContext::deinitializedNestedPassContext() const {
invocation->deinitializeNestedSwiftPassInvocation();
}
void BridgedPassContext::SSAUpdater_initialize(
BridgedFunction function, BridgedType type,
BridgedValue::Ownership ownership) const {

View File

@@ -74,6 +74,8 @@ class SwiftPassInvocation {
SILSSAUpdater *ssaUpdater = nullptr;
SwiftPassInvocation *nestedSwiftPassInvocation = nullptr;
static constexpr int BlockSetCapacity = SILBasicBlock::numCustomBits;
char blockSetStorage[sizeof(BasicBlockSet) * BlockSetCapacity];
bool aliveBlockSets[BlockSetCapacity];
@@ -181,6 +183,18 @@ public:
assert(ssaUpdater && "SSAUpdater not initialized");
return ssaUpdater;
}
SwiftPassInvocation *initializeNestedSwiftPassInvocation(SILFunction *newFunction) {
assert(!nestedSwiftPassInvocation && "Nested Swift pass invocation already initialized");
nestedSwiftPassInvocation = new SwiftPassInvocation(passManager, transform, newFunction);
return nestedSwiftPassInvocation;
}
void deinitializeNestedSwiftPassInvocation() {
assert(nestedSwiftPassInvocation && "Nested Swift pass invocation not initialized");
delete nestedSwiftPassInvocation;
nestedSwiftPassInvocation = nullptr;
}
};
/// The SIL pass manager.

View File

@@ -46,6 +46,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Casting.h"
#include <fstream>
using namespace swift;
@@ -1939,6 +1940,68 @@ void BridgedPassContext::moveFunctionBody(BridgedFunction sourceFunc, BridgedFun
invocation->getPassManager()->invalidateAnalysis(destFn, SILAnalysis::InvalidationKind::Everything);
}
BridgedFunction BridgedPassContext::
ClosureSpecializer_createEmptyFunctionWithSpecializedSignature(BridgedStringRef specializedName,
const BridgedParameterInfo * _Nullable specializedBridgedParams,
SwiftInt paramCount,
BridgedFunction bridgedApplySiteCallee,
bool isSerialized) const {
auto *applySiteCallee = bridgedApplySiteCallee.getFunction();
auto applySiteCalleeType = applySiteCallee->getLoweredFunctionType();
llvm::SmallVector<SILParameterInfo> specializedParams;
for (unsigned idx = 0; idx < paramCount; ++idx) {
specializedParams.push_back(specializedBridgedParams[idx].unbridged());
}
// The specialized function is always a thin function. This is important
// because we may add additional parameters after the Self parameter of
// witness methods. In this case the new function is not a method anymore.
auto extInfo = applySiteCalleeType->getExtInfo();
extInfo = extInfo.withRepresentation(SILFunctionTypeRepresentation::Thin);
auto ClonedTy = SILFunctionType::get(
applySiteCalleeType->getInvocationGenericSignature(), extInfo,
applySiteCalleeType->getCoroutineKind(),
applySiteCalleeType->getCalleeConvention(), specializedParams,
applySiteCalleeType->getYields(), applySiteCalleeType->getResults(),
applySiteCalleeType->getOptionalErrorResult(),
applySiteCalleeType->getPatternSubstitutions(),
applySiteCalleeType->getInvocationSubstitutions(),
applySiteCallee->getModule().getASTContext());
SILOptFunctionBuilder functionBuilder(*invocation->getTransform());
// We make this function bare so we don't have to worry about decls in the
// SILArgument.
auto *specializedApplySiteCallee = functionBuilder.createFunction(
// It's important to use a shared linkage for the specialized function
// and not the original linkage.
// Otherwise the new function could have an external linkage (in case the
// original function was de-serialized) and would not be code-gen'd.
// It's also important to disconnect this specialized function from any
// classes (the classSubclassScope), because that may incorrectly
// influence the linkage.
getSpecializedLinkage(applySiteCallee, applySiteCallee->getLinkage()), specializedName.unbridged(),
ClonedTy, applySiteCallee->getGenericEnvironment(),
applySiteCallee->getLocation(), IsBare, applySiteCallee->isTransparent(),
isSerialized ? IsSerialized : IsNotSerialized, IsNotDynamic, IsNotDistributed,
IsNotRuntimeAccessible, applySiteCallee->getEntryCount(),
applySiteCallee->isThunk(),
/*classSubclassScope=*/SubclassScope::NotApplicable,
applySiteCallee->getInlineStrategy(), applySiteCallee->getEffectsKind(),
applySiteCallee, applySiteCallee->getDebugScope());
if (!applySiteCallee->hasOwnership()) {
specializedApplySiteCallee->setOwnershipEliminated();
}
for (auto &Attr : applySiteCallee->getSemanticsAttrs())
specializedApplySiteCallee->addSemanticsAttr(Attr);
return {specializedApplySiteCallee};
}
bool FullApplySite_canInline(BridgedInstruction apply) {
return swift::SILInliner::canInlineApplySite(
swift::FullApplySite(apply.unbridged()));
@@ -2070,3 +2133,39 @@ void SILPassManager::runSwiftModuleVerification() {
runSwiftFunctionVerification(&f);
}
}
namespace swift {
class ClosureSpecializationCloner: public SILClonerWithScopes<ClosureSpecializationCloner> {
friend class SILInstructionVisitor<ClosureSpecializationCloner>;
friend class SILCloner<ClosureSpecializationCloner>;
public:
using SuperTy = SILClonerWithScopes<ClosureSpecializationCloner>;
ClosureSpecializationCloner(SILFunction &emptySpecializedFunction): SuperTy(emptySpecializedFunction) {}
};
} // namespace swift
BridgedSpecializationCloner::BridgedSpecializationCloner(BridgedFunction emptySpecializedFunction):
closureSpecCloner(new ClosureSpecializationCloner(*emptySpecializedFunction.getFunction())) {}
BridgedFunction BridgedSpecializationCloner::getCloned() const {
return { &closureSpecCloner->getBuilder().getFunction() };
}
BridgedBasicBlock BridgedSpecializationCloner::getClonedBasicBlock(BridgedBasicBlock originalBasicBlock) const {
return { closureSpecCloner->getOpBasicBlock(originalBasicBlock.unbridged()) };
}
void BridgedSpecializationCloner::cloneFunctionBody(BridgedFunction originalFunction, BridgedBasicBlock clonedEntryBlock, BridgedValueArray clonedEntryBlockArgs) const {
llvm::SmallVector<swift::SILValue, 16> clonedEntryBlockArgsStorage;
auto clonedEntryBlockArgsArrayRef = clonedEntryBlockArgs.getValues(clonedEntryBlockArgsStorage);
closureSpecCloner->cloneFunctionBody(originalFunction.getFunction(), clonedEntryBlock.unbridged(), clonedEntryBlockArgsArrayRef);
}
void BridgedBuilder::destroyCapturedArgs(BridgedInstruction partialApply) const {
if (auto *pai = llvm::dyn_cast<PartialApplyInst>(partialApply.unbridged()); pai->isOnStack()) {
auto b = unbridged();
return swift::insertDestroyOfCapturedArguments(pai, b);
} else {
assert(false && "`destroyCapturedArgs` must only be called on a `partial_apply` on stack!");
}
}

View File

@@ -10,8 +10,6 @@ import SwiftShims
import _Differentiation
// ===================== Gathering callsites and corresponding closures ===================== //
//////////////////////////////
// Single closure call site //
//////////////////////////////
@@ -34,12 +32,25 @@ bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> (Float, Float)):
// reverse-mode derivative of f(_:)
sil hidden @$s4test1fyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
//=========== Test callsite and closure gathering logic ===========//
specify_test "closure_specialize_gather_call_sites"
// CHECK-LABEL: Specializing closures in function: $s4test1fyS2fFTJrSpSr
// CHECK: PartialApply call site: %8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float // user: %9
// CHECK: Passed in closures:
// CHECK: 1. %6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %8
//=========== Test specialized function signature and body ===========//
specify_test "closure_specialize_specialized_function_signature_and_body"
// CHECK-LABEL: Generated specialized function: $s11$pullback_f12$vjpMultiplyS2fTf1nc_n
// CHECK: sil private @$s11$pullback_f12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float {
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float):
// CHECK: %3 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %4
// CHECK: %4 = partial_apply [callee_guaranteed] %3(%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // users: %6, %5
// CHECK: %5 = apply %4(%0) : $@callee_guaranteed (Float) -> (Float, Float) // users: %8, %7
// CHECK: strong_release %4 : $@callee_guaranteed (Float) -> (Float, Float) // id: %6
// CHECK: return
debug_value %0 : $Float, let, name "x", argno 1 // id: %1
%2 = struct_extract %0 : $Float, #Float._value // users: %3, %3
%3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %4
@@ -57,9 +68,8 @@ bb0(%0 : $Float):
///////////////////////////////
// Multiple closure callsite //
///////////////////////////////
sil @$_vjpSin : $@convention(thin) (Float, Float) -> Float // user: %6
sil @$_vjpCos : $@convention(thin) (Float, Float) -> Float // user: %10
sil @$_vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
sil @$vjpSin : $@convention(thin) (Float, Float) -> Float // user: %6
sil @$vjpCos : $@convention(thin) (Float, Float) -> Float // user: %10
// pullback of g(_:)
sil private @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float {
@@ -83,6 +93,7 @@ bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float, %2 : $@callee_guaran
// reverse-mode derivative of g(_:)
sil hidden @$s4test1gyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
//=========== Test callsite and closure gathering logic ===========//
specify_test "closure_specialize_gather_call_sites"
// CHECK-LABEL: Specializing closures in function: $s4test1gyS2fFTJrSpSr
// CHECK: PartialApply call site: %16 = partial_apply [callee_guaranteed] %15(%6, %10, %14) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float // user: %17
@@ -91,22 +102,41 @@ bb0(%0 : $Float):
// CHECK: 2. %10 = partial_apply [callee_guaranteed] %9(%0) : $@convention(thin) (Float, Float) -> Float // user: %16
// CHECK: 3. %14 = partial_apply [callee_guaranteed] %13(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %16
//=========== Test specialized function signature and body ===========//
specify_test "closure_specialize_specialized_function_signature_and_body"
// CHECK-LABEL: Generated specialized function: $s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n
// CHECK: sil private @$s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float {
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float, %3 : $Float, %4 : $Float):
// CHECK: %5 = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float // user: %6
// CHECK: %6 = partial_apply [callee_guaranteed] %5(%1) : $@convention(thin) (Float, Float) -> Float // users: %18, %17
// CHECK: %7 = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float // user: %8
// CHECK: %8 = partial_apply [callee_guaranteed] %7(%2) : $@convention(thin) (Float, Float) -> Float // users: %16, %15
// CHECK: %9 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %10
// CHECK: %10 = partial_apply [callee_guaranteed] %9(%3, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // users: %12, %11
// CHECK: %11 = apply %10(%0) : $@callee_guaranteed (Float) -> (Float, Float) // users: %14, %13
// CHECK: strong_release %10 : $@callee_guaranteed (Float) -> (Float, Float) // id: %12
// CHECK: %15 = apply %8(%14) : $@callee_guaranteed (Float) -> Float // user: %19
// CHECK: strong_release %8 : $@callee_guaranteed (Float) -> Float // id: %16
// CHECK: %17 = apply %6(%13) : $@callee_guaranteed (Float) -> Float // user: %20
// CHECK: strong_release %6 : $@callee_guaranteed (Float) -> Float // id: %18
// CHECK: return
debug_value %0 : $Float, let, name "x", argno 1 // id: %1
%2 = struct_extract %0 : $Float, #Float._value // users: %7, %3
%3 = builtin "int_sin_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // users: %11, %4
%4 = struct $Float (%3 : $Builtin.FPIEEE32) // user: %14
// function_ref closure #1 in _vjpSin(_:)
%5 = function_ref @$_vjpSin : $@convention(thin) (Float, Float) -> Float // user: %6
%5 = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float // user: %6
%6 = partial_apply [callee_guaranteed] %5(%0) : $@convention(thin) (Float, Float) -> Float // user: %16
%7 = builtin "int_cos_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // users: %11, %8
%8 = struct $Float (%7 : $Builtin.FPIEEE32) // user: %14
// function_ref closure #1 in _vjpCos(_:)
%9 = function_ref @$_vjpCos : $@convention(thin) (Float, Float) -> Float // user: %10
%9 = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float // user: %10
%10 = partial_apply [callee_guaranteed] %9(%0) : $@convention(thin) (Float, Float) -> Float // user: %16
%11 = builtin "fmul_FPIEEE32"(%3 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %12
%12 = struct $Float (%11 : $Builtin.FPIEEE32) // user: %17
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%13 = function_ref @$_vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %14
%13 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %14
%14 = partial_apply [callee_guaranteed] %13(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %16
// function_ref pullback of g(_:)
%15 = function_ref @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float // user: %16
@@ -114,3 +144,199 @@ bb0(%0 : $Float):
%17 = tuple (%12 : $Float, %16 : $@callee_guaranteed (Float) -> Float) // user: %18
return %17 : $(Float, @callee_guaranteed (Float) -> Float) // id: %18
}
///////////////////////////////
/// Parameter subset thunks ///
///////////////////////////////
struct X : Differentiable {
@_hasStorage var a: Float { get set }
@_hasStorage var b: Double { get set }
struct TangentVector : AdditiveArithmetic, Differentiable {
@_hasStorage var a: Float { get set }
@_hasStorage var b: Double { get set }
static func + (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector
static func - (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X.TangentVector, _ b: X.TangentVector) -> Bool
typealias TangentVector = X.TangentVector
init(a: Float, b: Double)
static var zero: X.TangentVector { get }
}
init(a: Float, b: Double)
mutating func move(by offset: X.TangentVector)
}
sil [transparent] [thunk] @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
sil @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector
sil shared @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> X.TangentVector):
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> X.TangentVector
strong_release %1 : $@callee_guaranteed (Float) -> X.TangentVector
return %2 : $X.TangentVector
}
sil hidden @$s5test21g1xSfAA1XV_tFTJrSpSr : $@convention(thin) (X) -> (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) {
bb0(%0 : $X):
//=========== Test callsite and closure gathering logic ===========//
specify_test "closure_specialize_gather_call_sites"
// CHECK-LABEL: Specializing closures in function: $s5test21g1xSfAA1XV_tFTJrSpSr
// CHECK: PartialApply call site: %7 = partial_apply [callee_guaranteed] %6(%5) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector // user: %8
// CHECK: Passed in closures:
// CHECK: 1. %3 = thin_to_thick_function %2 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector // user: %5
//=========== Test specialized function signature and body ===========//
specify_test "closure_specialize_specialized_function_signature_and_body"
// CHECK-LABEL: Generated specialized function: $s10pullback_g0A2_fTf1nc_n
// CHECK: sil shared @$s10pullback_g0A2_fTf1nc_n : $@convention(thin) (Float) -> X.TangentVector {
// CHECK: bb0(%0 : $Float):
// CHECK: %1 = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector // user: %2
// CHECK: %2 = thin_to_thick_function %1 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector // user: %4
// CHECK: %3 = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector // user: %4
// CHECK: %4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector // users: %6, %5
// CHECK: %5 = apply %4(%0) : $@callee_guaranteed (Float) -> X.TangentVector // user: %7
// CHECK: strong_release %4 : $@callee_guaranteed (Float) -> X.TangentVector // id: %6
// CHECK: return %5 : $X.TangentVector // id: %7
%1 = struct_extract %0 : $X, #X.a // user: %8
// function_ref pullback_f
%2 = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector // user: %3
%3 = thin_to_thick_function %2 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector // user: %5
// function_ref subset_parameter_thunk
%4 = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector // user: %5
%5 = partial_apply [callee_guaranteed] %4(%3) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector // user: %7
// function_ref pullback_g
%6 = function_ref @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector // user: %7
%7 = partial_apply [callee_guaranteed] %6(%5) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector // user: %8
%8 = tuple (%1 : $Float, %7 : $@callee_guaranteed (Float) -> X.TangentVector) // user: %9
return %8 : $(Float, @callee_guaranteed (Float) -> X.TangentVector) // id: %9
}
///////////////////////////////////////////////////////////////////////
///////// Specialized generic closures - PartialApply Closure /////////
///////////////////////////////////////////////////////////////////////
// closure #1 in static Float._vjpMultiply(lhs:rhs:)
sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
sil [transparent] [reabstraction_thunk] @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
// function_ref specialized pullback of f<A>(a:)
sil [transparent] [thunk] @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
// thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float)
sil [transparent] [reabstraction_thunk] @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
sil private [signature_optimized_thunk] [always_inline] @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float):
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> Float
strong_release %1 : $@callee_guaranteed (Float) -> Float
return %2 : $Float
}
// reverse-mode derivative of h(x:)
sil hidden @$s5test21h1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
//=========== Test callsite and closure gathering logic ===========//
specify_test "closure_specialize_gather_call_sites"
// CHECK-LABEL: Specializing closures in function: $s5test21h1xS2f_tFTJrSpSr
// CHECK: PartialApply call site: %14 = partial_apply [callee_guaranteed] %13(%11) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %15
// CHECK: Passed in closures:
// CHECK: 1. %4 = partial_apply [callee_guaranteed] %3(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %6
//=========== Test specialized function signature and body ===========//
specify_test "closure_specialize_specialized_function_signature_and_body"
// CHECK-LABEL: Generated specialized function: $s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n
// CHECK: sil private [signature_optimized_thunk] [always_inline] @$s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float {
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float):
// CHECK: %3 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %4
// CHECK: %4 = partial_apply [callee_guaranteed] %3(%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %6
// CHECK: %5 = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) // user: %6
// CHECK: %6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) // user: %7
// CHECK: %7 = convert_function %6 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float> // user: %9
// CHECK: %8 = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float // user: %9
// CHECK: %9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float // user: %11
// CHECK: %10 = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // user: %11
// CHECK: %11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // users: %13, %12
// CHECK: %12 = apply %11(%0) : $@callee_guaranteed (Float) -> Float // user: %14
// CHECK: strong_release %11 : $@callee_guaranteed (Float) -> Float // id: %13
// CHECK: return %12 : $Float
%1 = struct_extract %0 : $Float, #Float._value // users: %2, %2
%2 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %12
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%3 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %4
%4 = partial_apply [callee_guaranteed] %3(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %6
// function_ref thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
%5 = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) // user: %6
%6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) // user: %7
%7 = convert_function %6 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float> // user: %9
// function_ref pullback_f_specialized
%8 = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float // user: %9
%9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float // user: %11
// function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float)
%10 = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // user: %11
%11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // user: %14
%12 = struct $Float (%2 : $Builtin.FPIEEE32) // user: %15
// function_ref pullback_h
%13 = function_ref @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %14
%14 = partial_apply [callee_guaranteed] %13(%11) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %15
%15 = tuple (%12 : $Float, %14 : $@callee_guaranteed (Float) -> Float) // user: %16
return %15 : $(Float, @callee_guaranteed (Float) -> Float) // id: %16
}
//////////////////////////////////////////////////////////////////////////////
///////// Specialized generic closures - ThinToThickFunction closure /////////
//////////////////////////////////////////////////////////////////////////////
sil [transparent] [thunk] @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float
sil [transparent] [reabstraction_thunk] @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
sil private [signature_optimized_thunk] [always_inline] @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float):
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> Float
strong_release %1 : $@callee_guaranteed (Float) -> Float
return %2 : $Float
}
sil hidden @$s5test21z1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
//=========== Test callsite and closure gathering logic ===========//
specify_test "closure_specialize_gather_call_sites"
// CHECK-LABEL: Specializing closures in function: $s5test21z1xS2f_tFTJrSpSr
// CHECK: PartialApply call site: %6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %7
// CHECK: Passed in closures:
// CHECK: 1. %2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float // user: %4
//=========== Test specialized function signature and body ===========//
specify_test "closure_specialize_specialized_function_signature_and_body"
// CHECK-LABEL: Generated specialized function: $s10pullback_z0A14_y_specializedTf1nc_n
// CHECK: sil private [signature_optimized_thunk] [always_inline] @$s10pullback_z0A14_y_specializedTf1nc_n : $@convention(thin) (Float) -> Float {
// CHECK: bb0(%0 : $Float):
// CHECK: %1 = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float // user: %2
// CHECK: %2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float // user: %4
// CHECK: %3 = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // user: %4
// CHECK: %4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // users: %6, %5
// CHECK: %5 = apply %4(%0) : $@callee_guaranteed (Float) -> Float // user: %7
// CHECK: strong_release %4 : $@callee_guaranteed (Float) -> Float // id: %6
// CHECK: return %5 : $Float
// function_ref pullback_y_specialized
%1 = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float // user: %2
%2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float // user: %4
// function_ref reabstraction_thunk
%3 = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // user: %4
%4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // user: %6
// function_ref pullback_z
%5 = function_ref @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %6
%6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %7
%7 = tuple (%0 : $Float, %6 : $@callee_guaranteed (Float) -> Float) // user: %8
return %7 : $(Float, @callee_guaranteed (Float) -> Float) // id: %8
}

View File

@@ -1,173 +0,0 @@
// RUN: %target-sil-opt -test-runner %s -o /dev/null 2>&1 | %FileCheck %s
// REQUIRES: swift_in_compiler
// XFAIL: *
sil_stage canonical
import Builtin
import Swift
import SwiftShims
import _Differentiation
// ===================== Gathering callsites and corresponding closures ===================== //
///////////////////////////////
/// Parameter subset thunks ///
///////////////////////////////
struct X : Differentiable {
@_hasStorage var a: Float { get set }
@_hasStorage var b: Double { get set }
struct TangentVector : AdditiveArithmetic, Differentiable {
@_hasStorage var a: Float { get set }
@_hasStorage var b: Double { get set }
static func + (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector
static func - (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X.TangentVector, _ b: X.TangentVector) -> Bool
typealias TangentVector = X.TangentVector
init(a: Float, b: Double)
static var zero: X.TangentVector { get }
}
init(a: Float, b: Double)
mutating func move(by offset: X.TangentVector)
}
sil [transparent] [thunk] @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
sil @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector
sil @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector
sil hidden @$s5test21g1xSfAA1XV_tFTJrSpSr : $@convention(thin) (X) -> (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) {
bb0(%0 : $X):
specify_test "closure_specialize_gather_call_sites"
// CHECK-LABEL: Specializing closures in function: $s5test21g1xSfAA1XV_tFTJrSpSr
// CHECK: PartialApply call site: %9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector // user: %10
// CHECK: Passed in closures:
// CHECK: 1. %7 = partial_apply [callee_guaranteed] %6(%5) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector // user: %9
%3 = struct_extract %0 : $X, #X.a // user: %10
%4 = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector // user: %5
%5 = thin_to_thick_function %4 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector // user: %7
%6 = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector // user: %7
%7 = partial_apply [callee_guaranteed] %6(%5) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector // user: %9
%8 = function_ref @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector // user: %9
%9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector // user: %10
%10 = tuple (%3 : $Float, %9 : $@callee_guaranteed (Float) -> X.TangentVector) // user: %11
return %10 : $(Float, @callee_guaranteed (Float) -> X.TangentVector) // id: %11
}
///////////////////////////////////////////////////////////////////////
///////// Specialized generic closures - PartialApply Closure /////////
///////////////////////////////////////////////////////////////////////
// closure #1 in static Float._vjpMultiply(lhs:rhs:)
sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
sil [transparent] [reabstraction_thunk] @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
// function_ref specialized pullback of f<A>(a:)
sil [transparent] [thunk] @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
// thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float)
sil [transparent] [reabstraction_thunk] @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
sil private [signature_optimized_thunk] [always_inline] @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float):
%2 = integer_literal $Builtin.Int64, 0 // user: %3
%3 = builtin "sitofp_Int64_FPIEEE32"(%2 : $Builtin.Int64) : $Builtin.FPIEEE32 // users: %10, %5
%4 = struct_extract %0 : $Float, #Float._value // user: %5
%5 = builtin "fadd_FPIEEE32"(%3 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %6
%6 = struct $Float (%5 : $Builtin.FPIEEE32) // user: %7
%7 = apply %1(%6) : $@callee_guaranteed (Float) -> Float // user: %9
strong_release %1 : $@callee_guaranteed (Float) -> Float // id: %8
%9 = struct_extract %7 : $Float, #Float._value // user: %10
%10 = builtin "fadd_FPIEEE32"(%3 : $Builtin.FPIEEE32, %9 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %11
%11 = struct $Float (%10 : $Builtin.FPIEEE32) // users: %13, %12
debug_value %11 : $Float, let, name "x", argno 1 // id: %12
return %11 : $Float // id: %13
}
// reverse-mode derivative of h(x:)
sil hidden @$s5test21h1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
specify_test "closure_specialize_gather_call_sites"
// CHECK-LABEL: Specializing closures in function: $s5test21h1xS2f_tFTJrSpSr
// CHECK: PartialApply call site: %14 = partial_apply [callee_guaranteed] %13(%11) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %15
// CHECK: Passed in closures:
// CHECK: 1. %9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float // user: %11
%1 = struct_extract %0 : $Float, #Float._value // users: %2, %2
%2 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %12
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%3 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %4
%4 = partial_apply [callee_guaranteed] %3(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %6
// function_ref thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
%5 = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) // user: %6
%6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float) // user: %7
%7 = convert_function %6 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float> // user: %9
// function_ref pullback_f_specialized
%8 = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float // user: %9
%9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float // user: %11
// function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float)
%10 = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // user: %11
%11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // user: %14
%12 = struct $Float (%2 : $Builtin.FPIEEE32) // user: %15
// function_ref pullback_h
%13 = function_ref @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %14
%14 = partial_apply [callee_guaranteed] %13(%11) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %15
%15 = tuple (%12 : $Float, %14 : $@callee_guaranteed (Float) -> Float) // user: %16
return %15 : $(Float, @callee_guaranteed (Float) -> Float) // id: %16
}
//////////////////////////////////////////////////////////////////////////////
///////// Specialized generic closures - ThinToThickFunction closure /////////
//////////////////////////////////////////////////////////////////////////////
sil [transparent] [thunk] @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float
sil [transparent] [reabstraction_thunk] @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
sil private [signature_optimized_thunk] [always_inline] @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float):
%2 = integer_literal $Builtin.Int64, 0 // user: %3
%3 = builtin "sitofp_Int64_FPIEEE32"(%2 : $Builtin.Int64) : $Builtin.FPIEEE32 // users: %10, %5
%4 = struct_extract %0 : $Float, #Float._value // user: %5
%5 = builtin "fadd_FPIEEE32"(%3 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %6
%6 = struct $Float (%5 : $Builtin.FPIEEE32) // user: %7
%7 = apply %1(%6) : $@callee_guaranteed (Float) -> Float // user: %9
strong_release %1 : $@callee_guaranteed (Float) -> Float // id: %8
%9 = struct_extract %7 : $Float, #Float._value // user: %10
%10 = builtin "fadd_FPIEEE32"(%3 : $Builtin.FPIEEE32, %9 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %11
%11 = struct $Float (%10 : $Builtin.FPIEEE32) // users: %13, %12
debug_value %11 : $Float, let, name "x", argno 1 // id: %12
return %11 : $Float // id: %13
}
sil hidden @$s5test21z1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
specify_test "closure_specialize_gather_call_sites"
// CHECK-LABEL: Specializing closures in function: $s5test21z1xS2f_tFTJrSpSr
// CHECK: PartialApply call site: %6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %7
// CHECK: Passed in closures:
// CHECK: 1. %2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float // user: %4
// function_ref pullback_y_specialized
%1 = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float // user: %2
%2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float // user: %4
// function_ref reabstraction_thunk
%3 = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // user: %4
%4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float // user: %6
// function_ref pullback_z
%5 = function_ref @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %6
%6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float // user: %7
%7 = tuple (%0 : $Float, %6 : $@callee_guaranteed (Float) -> Float) // user: %8
return %7 : $(Float, @callee_guaranteed (Float) -> Float) // id: %8
}