Files
Daniil Kovalev d5eba5d245 [AutoDiff] SILCombine: handle convert_function use of differentiable_function (#88919)
The patch extends the SILCombine handling of `differentiable_function` to the
case where the borrowed value is first passed through a `convert_function`
whose result is then used by a `differentiable_function_extract`:

```
  %0 = differentiable_function ... %x
  %1 = begin_borrow %0
  %2 = convert_function %1 to ...
  %3 = differentiable_function_extract [xxx] %2
  // use of %3
```

-->

```
  %0 = differentiable_function ... %x
  // use of %x
```
2026-05-21 18:47:22 +00:00

185 lines
6.6 KiB
Swift

//===--- SimplifyDifferentiableFunction.swift -----------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2026 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
import SIL
extension DifferentiableFunctionInst : SILCombineSimplifiable {

  /// Rewrites `differentiable_function_extract`s which are located inside borrow scopes of an owned
  /// `differentiable_function` so that they no longer depend on the `differentiable_function` itself.
  ///
  /// Each extract is replaced by the corresponding operand of the `differentiable_function`.
  /// Because an operand with ownership is consumed by the `differentiable_function`, it is first
  /// copied right before the `differentiable_function`; the copy is then borrowed for the duration of
  /// the original borrow scope and destroyed at the end of that scope. Operands with `none` ownership
  /// are used directly without a copy. A `convert_function` is inserted in front of the extract if
  /// the type of the extract differs from the type of the operand.
  ///
  /// The extracts may use the `begin_borrow` directly or through a `convert_function` of the
  /// `begin_borrow`. The borrow scope of the `differentiable_function` itself is not removed here.
  ///
  /// ```
  ///   %3 = differentiable_function [parameters X] [results X] %orig with_derivative { %jvp, %vjp }
  ///   ...
  ///   %4 = begin_borrow %3
  ///   %5 = differentiable_function_extract [original] %4
  ///   %6 = differentiable_function_extract [jvp] %4
  ///   %7 = differentiable_function_extract [vjp] %4
  ///   use %5, %6, %7
  ///   end_borrow %4
  ///   ...
  ///   end_of_lifetime %3
  /// ```
  /// ->
  /// ```
  ///   %c = copy_value %orig              // only if %orig has ownership
  ///   %3 = differentiable_function [parameters X] [results X] %orig with_derivative { %jvp, %vjp }
  ///   ...
  ///   %b = begin_borrow %c
  ///   %4 = begin_borrow %3
  ///   %5 = convert_function %b to XXX    // only if the types differ
  ///   use %5                             // likewise for %6 (%jvp) and %7 (%vjp)
  ///   end_borrow %b
  ///   destroy_value %c                   // only if the copy was made
  ///   end_borrow %4
  ///   ...
  ///   end_of_lifetime %3
  /// ```
  func simplify(_ context: SimplifyContext) {
    guard ownership == .owned, hasOnlyExtractUsesInBorrowScopes() else {
      return
    }

    for use in uses {
      let user = use.instruction
      if let borrow = user as? BeginBorrowInst {
        splitAndRemoveExtracts(beginBorrow: borrow, context)
      } else if !(user is DebugValueInst) {
        // The precondition check above only lets borrows and lifetime-ending uses through.
        assert(use.endsLifetime)
      }
    }
  }

  /// Returns true if the simplification is applicable.
  ///
  /// Ignoring debug uses, this requires that
  /// * every use of this instruction is either a `begin_borrow` or ends the lifetime of the value,
  /// * every use of such a `begin_borrow` is an `end_borrow`, a `differentiable_function_extract` or a
  ///   `convert_function` which itself is only used by `differentiable_function_extract`s, and
  /// * there is at least one `differentiable_function_extract` among those uses.
  private func hasOnlyExtractUsesInBorrowScopes() -> Bool {
    var foundExtract = false

    for use in uses.ignoreDebugUses {
      guard let borrow = use.instruction as? BeginBorrowInst else {
        if use.endsLifetime {
          continue
        }
        return false
      }

      for borrowUse in borrow.uses.ignoreDebugUses {
        let borrowUser = borrowUse.instruction
        if borrowUser is EndBorrowInst {
          continue
        }
        if borrowUser is DifferentiableFunctionExtractInst {
          foundExtract = true
          continue
        }
        guard let convert = borrowUser as? ConvertFunctionInst else {
          return false
        }
        for convertUse in convert.uses.ignoreDebugUses {
          guard convertUse.instruction is DifferentiableFunctionExtractInst else {
            return false
          }
          foundExtract = true
        }
      }
    }
    return foundExtract
  }

  /// Replaces a single `differentiableFunctionExtract` with the operand of this instruction it extracts.
  ///
  /// `beginBorrow` is the borrow of this instruction through which the extract (directly or via a
  /// `convert_function`) accesses the value. Nothing is changed if `getExtractee` yields no value for
  /// the extract.
  private func processExtract(
    differentiableFunctionExtract: DifferentiableFunctionExtractInst,
    beginBorrow: BeginBorrowInst,
    _ context: SimplifyContext
  ) {
    guard let extractee = self.getExtractee(extractee: differentiableFunctionExtract.extractee) else {
      return
    }

    // An extractee with ownership is consumed by this `differentiable_function`. Therefore copy it
    // in front of the consumption point so that the copy stays available afterwards.
    let isCopied = extractee.ownership != .none
    let source: Value = isCopied
      ? Builder(before: self, context).createCopyValue(operand: extractee)
      : extractee

    let extractType = differentiableFunctionExtract.type
    let needsConversion = extractType != source.type

    // Replaces the extract with `replacement`, casting it to the type of the extract if needed.
    func replaceExtract(with replacement: Value) {
      guard needsConversion else {
        differentiableFunctionExtract.replace(with: replacement, context)
        return
      }
      let converted = Builder(before: differentiableFunctionExtract, context).createConvertFunction(
        originalFunction: replacement, resultType: extractType,
        withoutActuallyEscaping: false)
      differentiableFunctionExtract.replace(with: converted, context)
    }

    switch differentiableFunctionExtract.ownership {
    case .none:
      replaceExtract(with: source)

    case .guaranteed:
      // Open a borrow scope for the extractee which mirrors the scope of the original borrow.
      let fieldBorrow = Builder(before: beginBorrow, context).createBeginBorrow(
        of: source,
        isLexical: beginBorrow.isLexical,
        hasPointerEscape: beginBorrow.hasPointerEscape)

      replaceExtract(with: fieldBorrow)

      // Close the new scope - and release the copy - wherever the original scope ends.
      for scopeEnd in beginBorrow.endInstructions {
        let builder = Builder(before: scopeEnd, context)
        builder.createEndBorrow(of: fieldBorrow)
        if isCopied {
          builder.createDestroyValue(operand: source)
        }
      }

    case .owned, .unowned:
      fatalError("wrong ownership of differentiable_function_extract")
    }
  }

  /// Processes all `differentiable_function_extract`s which use `beginBorrow` directly or through a
  /// `convert_function` of `beginBorrow`.
  private func splitAndRemoveExtracts(beginBorrow: BeginBorrowInst, _ context: SimplifyContext) {
    func rewrite(_ extract: DifferentiableFunctionExtractInst) {
      processExtract(differentiableFunctionExtract: extract, beginBorrow: beginBorrow, context)
    }

    // Extracts which are applied directly to the borrowed value.
    for extract in beginBorrow.uses.users(ofType: DifferentiableFunctionExtractInst.self) {
      rewrite(extract)
    }

    // Extracts which are applied to a converted borrowed value. The use list is intentionally
    // re-read from `beginBorrow` here because the loop above replaces users of the borrow.
    for convert in beginBorrow.uses.users(ofType: ConvertFunctionInst.self) {
      for extract in convert.uses.users(ofType: DifferentiableFunctionExtractInst.self) {
        rewrite(extract)
      }
    }
  }
}