Add `Differentiable` requirements to pattern substitutions / pattern generic signature when calculating constrained function type. Also, add requirements for differentiable results as well.
Fixes#65487
Introduce the notion of "semantic result parameter". Handle differentiation of inouts via semantic result parameter abstraction. Do not consider non-wrt semantic result parameters as semantic results
Fixes#67174
Reformatting everything now that we have `llvm` namespaces. I've
separated this from the main commit to help manage merge-conflicts and
for making it a bit easier to read the mega-patch.
This is phase-1 of switching from llvm::Optional to std::optional in the
next rebranch. llvm::Optional was removed from upstream LLVM, so we need
to migrate off rather soon. On Darwin, std::optional, and llvm::Optional
have the same layout, so we don't need to be as concerned about ABI
beyond the name mangling. `llvm::Optional` is only returned from one
function in
```
getStandardTypeSubst(StringRef TypeName,
bool allowConcurrencyManglings);
```
It's the return value, so it should not impact the mangling of the
function, and the layout is the same as `std::optional`, so it should be
mostly okay. This function doesn't appear to have users, and the ABI was
already broken 2 years ago for concurrency and no one seemed to notice
so this should be "okay".
I'm doing the migration incrementally so that folks working on main can
cherry-pick back to the release/5.9 branch. Once 5.9 is done and locked
away, then we can go through and finish the replacement. Since `None`
and `Optional` show up in contexts where they are not `llvm::None` and
`llvm::Optional`, I'm preparing the work now by going through and
removing the namespace unwrapping and making the `llvm` namespace
explicit. This should make it fairly mechanical to go through and
replace llvm::Optional with std::optional, and llvm::None with
std::nullopt. It's also a change that can be brought onto the
release/5.9 with minimal impact. This should be an NFC change.
Compiler:
- Add `Forward` and `Reverse` to `DifferentiabilityKind`.
- Expand `DifferentiabilityMask` in `ExtInfo` to 3 bits so that it now holds all 4 cases of `DifferentiabilityKind`.
- Parse `@differentiable(reverse)` and `@differentiable(_forward)` declaration attributes and type attributes.
- Emit a warning for `@differentiable` without `reverse`.
- Emit an error for `@differentiable(_forward)`.
- Rename `@differentiable(linear)` to `@differentiable(_linear)`.
- Make `@differentiable(reverse)` type lowering go through today's `@differentiable` code path. We will specialize it to reverse-mode in a follow-up patch.
ABI:
- Add `Forward` and `Reverse` to `FunctionMetadataDifferentiabilityKind`.
- Extend `TargetFunctionTypeFlags` by 1 bit to store the highest bit of differentiability kind (linear). Note that there is a 2-bit gap in `DifferentiabilityMask` which is reserved for `AsyncMask` and `ConcurrentMask`; `AsyncMask` is ABI-stable so we cannot change that.
_Differentiation module:
- Replace all occurrences of `@differentiable` with `@differentiable(reverse)`.
- Delete `_transpose(of:)`.
Resolves rdar://69980056.
- `Mangle::ASTMangler::mangleAutoDiffDerivativeFunction()` and `Mangle::ASTMangler::mangleAutoDiffLinearMap()` accept original function declarations and return a mangled name for a derivative function or linear map. This is called during SILGen and TBDGen.
- `Mangle::DifferentiationMangler` handles differentiation function mangling in the differentiation transform. This part is necessary because we need to perform demangling on the original function and remangle it as part of a differentiation function mangling tree in order to get the correct substitutions in the mangled derivative generic signature.
A mangled differentiation function name includes:
- The original function.
- The differentiation function kind.
- The parameter indices for differentiation.
- The result indices for differentiation.
- The derivative generic signature.
Add request that resolves the "tangent stored property" corresponding to an
original stored property in a `Differentiable`-conforming type.
Enables better non-differentiability differentiation transform diagnostics.
`DifferentiableFunctionInst` now stores result indices.
`SILAutoDiffIndices` now stores result indices instead of a source index.
`@differentiable` SIL function types may now have multiple differentiability
result indices and `@noDerivative` resutls.
`@differentiable` AST function types do not have `@noDerivative` results (yet),
so this functionality is not exposed to users.
Resolves TF-689 and TF-1256.
Infrastructural support for TF-983: supporting differentiation of `apply`
instructions with multiple active semantic results.
Remove all assertions from `AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType`.
All error cases are represented by `DerivativeFunctionTypeError` now.
Fix `DerivativeFunctionTypeError` error payloads and improve error case naming.
Create `DerivativeFunctionTypeError` representing all potential derivative
function type calculation errors.
Make `AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType` return
`llvm::Expected<AnyFunctionType *>`. This is much safer and caller-friendly
than performing assertions.
Delete hacks in `@differentiable` and `@derivative` attribute type-checking
for verifying that `Differentiable.TangentVector` type witnesses are valid:
this is no longer necessary.
Robust fix for TF-521: invalid `Differentiable` conformances during
`@derivative` attribute type-checking.
Resolves SR-12793: bad interaction between `@differentiable` and `@derivative`
attribute type-checking and `Differentiable` derived conformances.
This simplifies fixing the master-next build. Upstream LLVM already
has a copy of this function, so on master-next we only need to delete
the Swift copy, reducing the potential for merge conflicts.
The differentiation transform does the following:
- Canonicalizes differentiability witnesses by filling in missing derivative
function entries.
- Canonicalizes `differentiable_function` instructions by filling in missing
derivative function operands.
- If necessary, performs automatic differentiation: generating derivative
functions for original functions.
- When encountering non-differentiability code, produces a diagnostic and
errors out.
Partially resolves TF-1211: add the main canonicalization loop.
To incrementally stage changes, derivative functions are currently created
with empty bodies that fatal error with a nice message.
Derivative emitters will be upstreamed separately.
Define type signatures and SILGen for the following builtins:
```
/// Applies the {jvp|vjp} of `f` to `arg1`, ..., `argN`.
func applyDerivative_arityN_{jvp|vjp}(f, arg1, ..., argN) -> jvp/vjp return type
/// Applies the transpose of `f` to `arg`.
func applyTranspose_arityN(f, arg) -> transpose return type
/// Makes a differentiable function from the given `original`, `jvp`, and
/// `vjp` functions.
func differentiableFunction_arityN(original, jvp, vjp)
/// Makes a linear function from the given `original` and `transpose` functions.
func linearFunction_arityN(original, transpose)
```
Add SILGen FileCheck tests for all builtins.
`@differentiable` attribute on protocol requirements and non-final class
members now produces derivative function entries in witness tables and vtables.
This enables `witness_method` and `class_method` differentiation.
Existing type-checking rules:
- Witness declarations of `@differentiable` protocol requirements must have a
`@differentiable` attribute with the same configuration (or a configuration
with superset parameter indices).
- Witness table derivative function entries are SILGen'd for `@differentiable`
witness declarations.
- Class vtable derivative function entries are SILGen'd for non-final
`@differentiable` class members.
- These derivative entries can be overridden or inherited, just like other
vtable entries.
Resolves TF-1212.
`@differentiable` attribute on protocol requirements and non-final class members
will produce derivative function entries in witness tables and vtables.
This patch adds an optional derivative function configuration
(`AutoDiffDerivativeFunctionIdentifier`) to `SILDeclRef` to represent these
derivative function entries.
Derivative function configurations consist of:
- A derivative function kind (JVP or VJP).
- Differentiability parameter indices.
Resolves TF-1209.
Enables TF-1212: upstream derivative function entries in witness tables/vtables.
Generate SIL differentiability witnesses from `@differentiable` and
`@derivative` declaration attributes.
Add SILGen utilities for:
- Emiting differentiability witnesses.
- Creating derivative function thunks, which are used as entries in
differentiability witnesses.
When users register a custom derivative function, it is necessary to create a
thunk with the expected derivative type computed from the original function's
type. This is important for consistent typing and consistent differentiability
witness entry mangling.
See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details.
Resolves TF-1138.
Semantically, an `inout` parameter is both a parameter and a result.
`@differentiable` and `@derivative` attributes now support original functions
with one "semantic result": either a formal result or an `inout` parameter.
Derivative typing rules for functions with `inout` parameters are now defined.
The differential/pullback type of a function with `inout` differentiability
parameters also has `inout` parameters. This is ideal for performance.
Differential typing rules:
- Case 1: original function has no `inout` parameters.
- Original: `(T0, T1, ...) -> R`
- Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
- Case 2: original function has a non-wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Differential: `(T0.Tan, ...) -> T1.Tan`
- Case 3: original function has a wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
Pullback typing rules:
- Case 1: original function has no `inout` parameters.
- Original: `(T0, T1, ...) -> R`
- Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
- Case 2: original function has a non-wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Pullback: `(T1.Tan) -> (T0.Tan, ...)`
- Case 3: original function has a wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
Resolves TF-1164.
Using `DenseMap<SILAutoDiffDerivativeFunctionKey, ...>` in a file that includes
AutoDiff.h but not Types.h produced the following error:
```
error: invalid application of 'alignof' to an incomplete type 'swift::SILFunctionType'
```
Fix the error by:
- Including TypeAlignments.h in AutoDiff.h.
- Adding an alignment forward declaration for `SILFunctionType` to TypeAlignments.h.
The `differentiability_witness_function` instruction looks up a
differentiability witness function (JVP, VJP, or transpose) for a referenced
function via SIL differentiability witnesses.
Add round-trip parsing/serialization and IRGen tests.
Notes:
- Differentiability witnesses for linear functions require more support.
`differentiability_witness_function [transpose]` instructions do not yet
have IRGen.
- Nothing currently generates `differentiability_witness_function` instructions.
The differentiation transform does, but it hasn't been upstreamed yet.
Resolves TF-1141.
Add `SILFunctionType::getAutoDiffTransposeFunctionType`.
It computes the transpose `SILFucntionType` for an original `SILFunctionType`,
given:
- Linearity parameter indices
- Transpose function generic signature (optional)
- Other auxiliary parameters
Add doc comments explaining typing rules, preconditions, and other details.
Add `isTranspose` flag to `autodiff::getConstrainedDerivativeGenericSignature`.
Partially resolves TF-1125.
Unblocks TF-1141: upstream `differentiability_witness_function` instruction.
SIL differentiability witnesses are a new top-level SIL construct mapping
"original" SIL functions to derivative SIL functions.
SIL differentiability witnesses have the following components:
- "Original" `SILFunction`.
- SIL linkage.
- Differentiability parameter indices (`IndexSubset`).
- Differentiability result indices (`IndexSubset`).
- Derivative `GenericSignature` representing differentiability generic
requirements (optional).
- JVP derivative `SILFunction` (optional).
- VJP derivative `SILFunction` (optional).
- "Is serialized?" bit.
This patch adds the `SILDifferentiabilityWitness` data structure, with
documentation, parsing, and printing.
Resolves TF-911.
Todos:
- TF-1136: upstream `SILDifferentiabilityWitness` serialization.
- TF-1137: upstream `SILDifferentiabilityWitness` verification.
- TF-1138: upstream `SILDifferentiabilityWitness` SILGen from
`@differentiable` and `@derivative` attributes.
- TF-20: robust mangling for `SILDifferentiabilityWitness` names.
`TangentSpace` represents the tangent space of a type.
- For `Differentiable`-conforming types:
- The tangent space is the `TangentVector` associated type.
- For tuple types:
- The tangent space is a tuple of the elements' tangent space types, for the
elements that have a tangent space.
- Other types have no tangent space.
`TypeBase::getAutoDiffTangentSpace` gets the tangent space of a type.
`TangentSpace` is used to:
- Compute the derivative function type of a given original function type.
- Compute the type of tangent/adjoint values during automatic differentiation.
Progress towards TF-828: upstream `@differentiable` attribute type-checking.
This is needed to build with VS2019:
```
swift\include\swift\AST\AutoDiff.h(172): error C2593: 'operator ==' is ambiguous
swift\include\swift\Basic\AnyValue.h(129): note: could be 'bool swift::operator ==(const swift::AnyValue &,const swift::AnyValue &)' [found using argument-dependent lookup]
swift\include\swift\AST\AutoDiff.h(172): note: or 'built-in C++ operator==(swift::AutoDiffDerivativeFunctionKind::innerty, swift::AutoDiffDerivativeFunctionKind::innerty)'
swift\include\swift\AST\AutoDiff.h(172): note: while trying to match the argument list '(const swift::AutoDiffDerivativeFunctionKind, const swift::AutoDiffDerivativeFunctionKind)'
```
The `@derivative` attribute registers a function as a derivative of another
function-like declaration: a `func`, `init`, `subscript`, or `var` computed
property declaration.
The `@derivative` attribute also has an optional `wrt:` clause specifying the
parameters that are differentiated "with respect to", i.e. the differentiation
parameters. The differentiation parameters must conform to the `Differentiable`
protocol.
If the `wrt:` clause is unspecified, the differentiation parameters are inferred
to be all parameters that conform to `Differentiable`.
`@derivative` attribute type-checking verifies that the type of the derivative
function declaration is consistent with the type of the referenced original
declaration and the differentiation parameters.
The `@derivative` attribute is gated by the
`-enable-experimental-differentiable-programming` flag.
Resolves TF-829.
The `@derivative(of:)` attribute registers a function as a derivative of another
function. This patch adds the `@derivative(of:)` attribute definition, syntax,
parsing, and printing.
Resolves TF-826.
Todos:
- Type-checking (TF-829).
- Serialization (TF-837).
This PR introduces `@differentiable` attribute to mark functions as differentiable. This PR only contains changes related to parsing the attribute. Type checking and other changes will be added in subsequent patches.
See https://github.com/apple/swift/pull/27506/files#diff-f3216f4188fd5ed34e1007e5a9c2490f for examples and tests for the new attribute.