Merge remote-tracking branch 'origin/main' into rebranch

This commit is contained in:
swift-ci
2025-05-25 06:53:22 -07:00
30 changed files with 476 additions and 31 deletions

View File

@@ -4387,6 +4387,9 @@ NOTE(derivative_attr_fix_access,none,
"mark the derivative function as "
"'%select{private|fileprivate|internal|package|@usableFromInline|@usableFromInline}0' "
"to match the original function", (AccessLevel))
ERROR(derivative_attr_always_emit_into_client_mismatch,none,
"either both or none of derivative and original function must have "
"@alwaysEmitIntoClient attribute", ())
ERROR(derivative_attr_static_method_mismatch_original,none,
"unexpected derivative function declaration; "
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",

View File

@@ -159,8 +159,22 @@ void SILLinkerVisitor::maybeAddFunctionToWorklist(
// HiddenExternal linkage when they are declarations, then they
// become Shared after the body has been deserialized.
// So try deserializing HiddenExternal functions too.
if (linkage == SILLinkage::HiddenExternal)
return deserializeAndPushToWorklist(F);
if (linkage == SILLinkage::HiddenExternal) {
deserializeAndPushToWorklist(F);
if (!F->markedAsAlwaysEmitIntoClient())
return;
// For @_alwaysEmitIntoClient functions, we need to lookup its
// differentiability witness and, if present, ask SILLoader to obtain its
// definition. Otherwise, a linker error would occur due to undefined
// reference to these symbols.
for (SILDifferentiabilityWitness *witness :
F->getModule().lookUpDifferentiabilityWitnessesForFunction(
F->getName())) {
F->getModule().getSILLoader()->lookupDifferentiabilityWitness(
witness->getKey());
}
return;
}
// Update the linkage of the function in case it's different in the serialized
// SIL than derived from the AST. This can be the case with cross-module-

View File

@@ -1435,14 +1435,19 @@ void SILGenModule::emitDifferentiabilityWitness(
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
if (!diffWitness) {
// Differentiability witnesses have the same linkage as the original
// function, stripping external.
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
// function, stripping external. For @_alwaysEmitIntoClient original
// functions, force PublicNonABI linkage of the differentiability witness so
// we can serialize it (the original function itself might be HiddenExternal
// in this case if we only have declaration without definition).
auto linkage =
originalFunction->markedAsAlwaysEmitIntoClient()
? SILLinkage::PublicNonABI
: stripExternalFromLinkage(originalFunction->getLinkage());
diffWitness = SILDifferentiabilityWitness::createDefinition(
M, linkage, originalFunction, diffKind, silConfig.parameterIndices,
silConfig.resultIndices, config.derivativeGenericSignature,
/*jvp*/ nullptr, /*vjp*/ nullptr,
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
attr);
/*isSerialized*/ hasPublicVisibility(linkage), attr);
}
// Set derivative function in differentiability witness.

View File

@@ -6498,8 +6498,14 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
auto loc = customDerivativeFn->getLocation();
SILGenFunctionBuilder fb(*this);
// Derivative thunks have the same linkage as the original function, stripping
// external.
auto linkage = stripExternalFromLinkage(originalFn->getLinkage());
// external. For @_alwaysEmitIntoClient original functions, force PublicNonABI
// linkage of derivative thunks so we can serialize them (the original
// function itself might be HiddenExternal in this case if we only have
// declaration without definition).
auto linkage = originalFn->markedAsAlwaysEmitIntoClient()
? SILLinkage::PublicNonABI
: stripExternalFromLinkage(originalFn->getLinkage());
auto *thunk = fb.getOrCreateFunction(
loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent,
customDerivativeFn->getSerializedKind(),

View File

@@ -538,9 +538,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
"definitions with explicit differentiable attributes");
return SILDifferentiabilityWitness::createDeclaration(
module, SILLinkage::PublicExternal, original, kind,
minimalConfig->parameterIndices, minimalConfig->resultIndices,
minimalConfig->derivativeGenericSignature);
module,
// Witness for @_alwaysEmitIntoClient original function must be emitted,
// otherwise a linker error would occur due to undefined reference to the
// witness symbol.
original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI
: SILLinkage::PublicExternal,
original, kind, minimalConfig->parameterIndices,
minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
}
} // end namespace autodiff

View File

@@ -999,10 +999,14 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
// We can generate empty JVP / VJP for functions available externally. These
// functions have the same linkage as the original ones sans `external`
// flag. Important exception here hidden_external functions as they are
// serializable but corresponding hidden ones would be not and the SIL
// verifier will fail. Patch `serializeFunctions` for this case.
if (orig->getLinkage() == SILLinkage::HiddenExternal)
// flag. Important exception here hidden_external non-@_alwaysEmitIntoClient
// functions as they are serializable but corresponding hidden ones would be
// not and the SIL verifier will fail. Patch `serializeFunctions` for this
// case. For @_alwaysEmitIntoClient original functions (which might be
// HiddenExternal if we only have declaration without definition), we want
// derivatives to be serialized and do not patch `serializeFunctions`.
if (orig->getLinkage() == SILLinkage::HiddenExternal &&
!orig->markedAsAlwaysEmitIntoClient())
serializeFunctions = IsNotSerialized;
// If the JVP doesn't exist, need to synthesize it.

View File

@@ -6990,6 +6990,13 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
return true;
}
if (originalAFD->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>() !=
derivative->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>()) {
diags.diagnose(derivative->getLoc(),
diag::derivative_attr_always_emit_into_client_mismatch);
return true;
}
// Get the resolved differentiability parameter indices.
auto *resolvedDiffParamIndices = attr->getParameterIndices();

View File

@@ -405,9 +405,6 @@ where
}
}
// FIXME(TF-1103): Derivative registration does not yet support
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
/*
extension SIMD
where
Self: Differentiable,
@@ -417,6 +414,7 @@ where
TangentVector == Self
{
@inlinable
@_alwaysEmitIntoClient
@derivative(of: sum)
func _vjpSum() -> (
value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
@@ -425,6 +423,7 @@ where
}
@inlinable
@_alwaysEmitIntoClient
@derivative(of: sum)
func _jvpSum() -> (
value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
@@ -432,7 +431,6 @@ where
return (sum(), { v in Scalar.TangentVector(v.sum()) })
}
}
*/
extension SIMD
where

View File

@@ -1,8 +1,9 @@
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify %s | %FileCheck %s
/// Note: -primary-file prevents non_abi->shared linkage change in `removeSerializedFlagFromAllFunctions`
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify -primary-file %s | %FileCheck %s
import _Differentiation
// CHECK: sil @test_nil_coalescing
// CHECK: sil non_abi @test_nil_coalescing
// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional<T>, %[[ARG_PB:.*]] :
// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>):
// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional<T>
@@ -15,7 +16,7 @@ import _Differentiation
//
@_silgen_name("test_nil_coalescing")
@derivative(of: ??)
@usableFromInline
@_alwaysEmitIntoClient
func nilCoalescing<T: Differentiable>(optional: T?, defaultValue: @autoclosure () throws -> T)
rethrows -> (value: T, pullback: (T.TangentVector) -> Optional<T>.TangentVector)
{

View File

@@ -1062,6 +1062,15 @@ func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullb
fatalError()
}
func internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
@_alwaysEmitIntoClient
@derivative(of: internal_original_alwaysemitintoclient_derivative_error)
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
func _internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}
@_alwaysEmitIntoClient
func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
@_alwaysEmitIntoClient
@derivative(of: internal_original_alwaysemitintoclient_derivative)
@@ -1084,6 +1093,15 @@ package func _package_original_inlinable_derivative(_ x: Float) -> (value: Float
fatalError()
}
@_alwaysEmitIntoClient
package func package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
@derivative(of: package_original_alwaysemitintoclient_derivative_error)
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
package func _package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}
@_alwaysEmitIntoClient
package func package_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
@_alwaysEmitIntoClient
@derivative(of: package_original_alwaysemitintoclient_derivative)

View File

@@ -19,9 +19,6 @@ SIMDTests.test("init(repeating:)") {
expectEqual(8, pb1(g))
}
// FIXME(TF-1103): Derivative registration does not yet support
// `@_alwaysEmitIntoClient` original functions.
/*
SIMDTests.test("Sum") {
let a = SIMD4<Float>(1, 2, 3, 4)
@@ -32,7 +29,6 @@ SIMDTests.test("Sum") {
expectEqual(10, val1)
expectEqual(SIMD4<Float>(3, 3, 3, 3), pb1(3))
}
*/
SIMDTests.test("Identity") {
let a = SIMD4<Float>(1, 2, 3, 4)
@@ -289,9 +285,6 @@ SIMDTests.test("Generics") {
expectEqual(SIMD3<Double>(5, 10, 15), val4)
expectEqual((SIMD3<Double>(5, 5, 5), 6), pb4(g))
// FIXME(TF-1103): Derivative registration does not yet support
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
/*
func testSum<Scalar, SIMDType: SIMD>(x: SIMDType) -> Scalar
where SIMDType.Scalar == Scalar,
SIMDType : Differentiable,
@@ -304,7 +297,6 @@ SIMDTests.test("Generics") {
let (val5, pb5) = valueWithPullback(at: a, of: simd3Sum)
expectEqual(6, val5)
expectEqual(SIMD3<Double>(7, 7, 7), pb5(7))
*/
}
runAllTests()

View File

@@ -0,0 +1,4 @@
@_alwaysEmitIntoClient
public func f(_ x: Float) -> Float {
x
}

View File

@@ -0,0 +1,7 @@
import _Differentiation
@derivative(of: f)
@_alwaysEmitIntoClient
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { 42 * $0 })
}

View File

@@ -0,0 +1,4 @@
@_alwaysEmitIntoClient
public func f(_ x: Float) -> Float {
x
}

View File

@@ -0,0 +1,8 @@
import MultiModule1
import _Differentiation
@derivative(of: f)
@_alwaysEmitIntoClient
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { 42 * $0 })
}

View File

@@ -0,0 +1,16 @@
import _Differentiation
public protocol Protocol {
var x : Float {get set}
init()
}
extension Protocol {
public init(_ val: Float) {
self.init()
x = val
}
@_alwaysEmitIntoClient
public func sum() -> Float { x }
}

View File

@@ -0,0 +1,20 @@
import MultiModuleProtocol1
import _Differentiation
extension Protocol where Self: Differentiable, Self.TangentVector == Self {
@_alwaysEmitIntoClient
@derivative(of: sum)
public func _vjpSum() -> (
value: Float, pullback: (Float) -> Self.TangentVector
) {
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
}
@_alwaysEmitIntoClient
@derivative(of: sum)
public func _jvpSum() -> (
value: Float, differential: (Self.TangentVector) -> Float
) {
(value: self.x, differential: { 42 * $0.x })
}
}

View File

@@ -0,0 +1,24 @@
import MultiModuleProtocol1
import MultiModuleProtocol2
import _Differentiation
public struct Struct : Protocol {
private var _x : Float
public var x : Float {
get { _x }
set { _x = newValue }
}
public init() { _x = 0 }
}
extension Struct : AdditiveArithmetic {
public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) }
public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) }
public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x }
public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x }
public static var zero: Self { Self(0) }
}
extension Struct : Differentiable {
public typealias TangentVector = Self
}

View File

@@ -0,0 +1,23 @@
public struct Struct {
public var x : Float
public typealias TangentVector = Self
public init() { x = 0 }
}
extension Struct {
public init(_ val: Float) {
self.init()
x = val
}
@_alwaysEmitIntoClient
public func sum() -> Float { x }
}
extension Struct : AdditiveArithmetic {
public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) }
public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) }
public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x }
public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x }
public static var zero: Self { Self(0) }
}

View File

@@ -0,0 +1,20 @@
import MultiModuleStruct1
import _Differentiation
extension Struct : Differentiable {
@_alwaysEmitIntoClient
@derivative(of: sum)
public func _vjpSum() -> (
value: Float, pullback: (Float) -> Self.TangentVector
) {
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
}
@_alwaysEmitIntoClient
@derivative(of: sum)
public func _jvpSum() -> (
value: Float, differential: (Self.TangentVector) -> Float
) {
(value: self.x, differential: { 42 * $0.x })
}
}

View File

@@ -0,0 +1,12 @@
import MultiModuleStruct1
import _Differentiation
extension Struct : Differentiable {
@_alwaysEmitIntoClient
@derivative(of: sum)
public func _vjpSum() -> (
value: Float, pullback: (Float) -> Self.TangentVector
) {
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
}
}

View File

@@ -0,0 +1,12 @@
import MultiModuleStruct1
import _Differentiation
extension Struct : Differentiable {
@_alwaysEmitIntoClient
@derivative(of: sum)
public func _jvpSum() -> (
value: Float, differential: (Self.TangentVector) -> Float
) {
(value: self.x, differential: { 42 * $0.x })
}
}

View File

@@ -0,0 +1,12 @@
import _Differentiation
@_alwaysEmitIntoClient
public func f(_ x: Float) -> Float {
x
}
@derivative(of: f)
@_alwaysEmitIntoClient
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { 42 * $0 })
}

View File

@@ -0,0 +1,28 @@
// REQUIRES: executable_test
// RUN: %empty-directory(%t)
/// Note: we build just a module without a library since it would not contain any exported
/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient.
// RUN: %target-build-swift %S/Inputs/MultiFileModule/file1.swift %S/Inputs/MultiFileModule/file2.swift \
// RUN: -emit-module -emit-module-path %t/MultiFileModule.swiftmodule -module-name MultiFileModule
// RUN: %target-build-swift -I%t %s -o %t/a.out %target-rpath(%t)
// RUN: %target-run %t/a.out
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
import MultiFileModule
import StdlibUnittest
import _Differentiation
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
AlwaysEmitIntoClientTests.test("registration") {
expectEqual(42, gradient(at: 0, of: f))
expectEqual(42, gradient(at: 1, of: f))
expectEqual(42, gradient(at: 2, of: f))
}
runAllTests()
// CHECK: @"15MultiFileModule1fyS2fFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s15MultiFileModule1fyS2fFTJfSpSr", ptr @"$s15MultiFileModule1fyS2fFTJrSpSr" }

View File

@@ -0,0 +1,31 @@
// REQUIRES: executable_test
// RUN: %empty-directory(%t)
/// Note: we build just modules without libraries since they would not contain any exported
/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient.
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModule1)) %S/Inputs/MultiModule/file1.swift \
// RUN: -emit-module -emit-module-path %t/MultiModule1.swiftmodule -module-name MultiModule1
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModule2)) %S/Inputs/MultiModule/file2.swift \
// RUN: -emit-module -emit-module-path %t/MultiModule2.swiftmodule -module-name MultiModule2 -I%t %target-rpath(%t)
// RUN: %target-build-swift -I%t %s -o %t/a.out %target-rpath(%t)
// RUN: %target-run %t/a.out
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
import MultiModule1
import MultiModule2
import StdlibUnittest
import _Differentiation
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
AlwaysEmitIntoClientTests.test("registration") {
expectEqual(42, gradient(at: 0, of: f))
expectEqual(42, gradient(at: 1, of: f))
expectEqual(42, gradient(at: 2, of: f))
}
runAllTests()
// CHECK: @"12MultiModule11fyS2fFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s12MultiModule11fyS2fFTJfSpSr", ptr @"$s12MultiModule11fyS2fFTJrSpSr" }

View File

@@ -0,0 +1,44 @@
// REQUIRES: executable_test
// RUN: %empty-directory(%t)
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleProtocol1)) %S/Inputs/MultiModuleProtocol/file1.swift \
// RUN: -emit-module -emit-module-path %t/MultiModuleProtocol1.swiftmodule -module-name MultiModuleProtocol1
/// Note: we build just a module without a library since it would not contain any exported
/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient.
// RUN: %target-build-swift %S/Inputs/MultiModuleProtocol/file2.swift -emit-module -emit-module-path %t/MultiModuleProtocol2.swiftmodule \
// RUN: -module-name MultiModuleProtocol2 -I%t -lMultiModuleProtocol1 %target-rpath(%t)
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleProtocol3)) %S/Inputs/MultiModuleProtocol/file3.swift \
// RUN: -emit-module -emit-module-path %t/MultiModuleProtocol3.swiftmodule -module-name MultiModuleProtocol3 -I%t -L%t -lMultiModuleProtocol1 %target-rpath(%t)
/// Note: we enable forward-mode differentiation to automatically generate JVP for `foo`.
/// It wraps `Protocol.sum` that has custom JVP defined in MultiModuleProtocol2, so we can test it.
// RUN: %target-build-swift -Xfrontend -enable-experimental-forward-mode-differentiation \
// RUN: -I%t -L%t %s -lMultiModuleProtocol1 -lMultiModuleProtocol3 -o %t/a.out %target-rpath(%t)
// RUN: %target-run %t/a.out
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
import MultiModuleProtocol1
import MultiModuleProtocol2
import MultiModuleProtocol3
import StdlibUnittest
import _Differentiation
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
AlwaysEmitIntoClientTests.test("registration") {
func foo<T: Protocol>(x: T) -> Float
where T: Differentiable, T.TangentVector == T { x.sum() }
expectEqual(Struct(42), pullback(at: Struct(0), of: foo)(1))
expectEqual(Struct(42), pullback(at: Struct(1), of: foo)(1))
expectEqual(Struct(42), pullback(at: Struct(2), of: foo)(1))
expectEqual(42, differential(at: Struct(0), of: foo)(Struct(1)))
expectEqual(42, differential(at: Struct(1), of: foo)(Struct(1)))
expectEqual(42, differential(at: Struct(2), of: foo)(Struct(1)))
}
runAllTests()
// CHECK: @"20MultiModuleProtocol18ProtocolPAAE3sumSfyFAaBRz16_Differentiation14DifferentiableRz13TangentVectorAeFPQzRszlWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s20MultiModuleProtocol18ProtocolPAAE3sumSfyFAaBRz16_Differentiation14DifferentiableRz13TangentVectorAeFPQzRszlTJfSpSr", ptr @"$s20MultiModuleProtocol18ProtocolPAAE3sumSfyFAaBRz16_Differentiation14DifferentiableRz13TangentVectorAeFPQzRszlTJrSpSr" }

View File

@@ -0,0 +1,36 @@
// REQUIRES: executable_test
// RUN: %empty-directory(%t)
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct1)) %S/Inputs/MultiModuleStruct/file1.swift \
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct1.swiftmodule -module-name MultiModuleStruct1
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct2)) %S/Inputs/MultiModuleStruct/file2.swift \
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct2.swiftmodule -module-name MultiModuleStruct2 -I%t -L%t -lMultiModuleStruct1 %target-rpath(%t)
/// Note: we enable forward-mode differentiation to automatically generate JVP for `foo`.
/// It wraps `Struct.sum` that has custom JVP defined in MultiModuleStruct2, so we can test it.
// RUN: %target-build-swift -Xfrontend -enable-experimental-forward-mode-differentiation \
// RUN: -I%t -L%t %s -lMultiModuleStruct1 -lMultiModuleStruct2 -o %t/a.out %target-rpath(%t)
// RUN: %target-run %t/a.out
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
import MultiModuleStruct1
import MultiModuleStruct2
import StdlibUnittest
import _Differentiation
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
AlwaysEmitIntoClientTests.test("registration") {
func foo(x: Struct) -> Float { x.sum() }
expectEqual(Struct(42), pullback(at: Struct(0), of: foo)(1))
expectEqual(Struct(42), pullback(at: Struct(1), of: foo)(1))
expectEqual(Struct(42), pullback(at: Struct(2), of: foo)(1))
expectEqual(42, differential(at: Struct(0), of: foo)(Struct(1)))
expectEqual(42, differential(at: Struct(1), of: foo)(Struct(1)))
expectEqual(42, differential(at: Struct(2), of: foo)(Struct(1)))
}
runAllTests()
// CHECK: @"18MultiModuleStruct16StructV3sumSfyFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJfSpSr", ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJrSpSr" }

View File

@@ -0,0 +1,37 @@
// REQUIRES: executable_test
// RUN: %empty-directory(%t)
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct1)) %S/Inputs/MultiModuleStruct/file1.swift \
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct1.swiftmodule -module-name MultiModuleStruct1
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct2NoJVP)) %S/Inputs/MultiModuleStruct/file2_no_jvp.swift \
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct2NoJVP.swiftmodule -module-name MultiModuleStruct2NoJVP -I%t -L%t -lMultiModuleStruct1 %target-rpath(%t)
/// Note: we enable forward-mode differentiation to automatically generate JVP for `foo`.
/// It wraps `Struct.sum` that has custom JVP defined in MultiModuleStruct2, so we can test it.
// RUN: %target-build-swift -Xfrontend -enable-experimental-forward-mode-differentiation \
// RUN: -I%t -L%t %s -lMultiModuleStruct1 -lMultiModuleStruct2NoJVP -o %t/a.out %target-rpath(%t)
// RUN: %target-run %t/a.out
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
import MultiModuleStruct1
import MultiModuleStruct2NoJVP
import StdlibUnittest
import _Differentiation
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
AlwaysEmitIntoClientTests.test("registration") {
func foo(x: Struct) -> Float { x.sum() }
expectEqual(Struct(42), pullback(at: Struct(0), of: foo)(1))
expectEqual(Struct(42), pullback(at: Struct(1), of: foo)(1))
expectEqual(Struct(42), pullback(at: Struct(2), of: foo)(1))
/// Custom JVP for Struct.sum is not provided, a JVP causing fatal error is emitted.
expectCrash{differential(at: Struct(0), of: foo)(Struct(1))}
expectCrash{differential(at: Struct(1), of: foo)(Struct(1))}
expectCrash{differential(at: Struct(2), of: foo)(Struct(1))}
}
runAllTests()
// CHECK: @"18MultiModuleStruct16StructV3sumSfyFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJfSpSr", ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJrSpSr" }

View File

@@ -0,0 +1,26 @@
// REQUIRES: executable_test
// RUN: %empty-directory(%t)
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct1)) %S/Inputs/MultiModuleStruct/file1.swift \
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct1.swiftmodule -module-name MultiModuleStruct1
// RUN: not %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct2NoVJP)) %S/Inputs/MultiModuleStruct/file2_no_vjp.swift \
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct2NoVJP.swiftmodule -module-name MultiModuleStruct2NoVJP -I%t -L%t -lMultiModuleStruct1 %target-rpath(%t) 2>&1 | \
// RUN: %FileCheck %s
// CHECK: file2_no_vjp.swift:6:4: error: function is not differentiable
import MultiModuleStruct1
import MultiModuleStruct2
import StdlibUnittest
import _Differentiation
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
AlwaysEmitIntoClientTests.test("registration") {
func foo(x: Struct) -> Float { x.sum() }
expectEqual(42, differential(at: Struct(0), of: foo)(Struct(1)))
expectEqual(42, differential(at: Struct(1), of: foo)(Struct(1)))
expectEqual(42, differential(at: Struct(2), of: foo)(Struct(1)))
}
runAllTests()

View File

@@ -0,0 +1,28 @@
// REQUIRES: executable_test
// RUN: %empty-directory(%t)
/// Note: we build just a module without a library since it would not contain any exported
/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient.
// RUN: %target-build-swift %S/Inputs/SingleFileModule/file.swift -emit-module \
// RUN: -emit-module-path %t/SingleFileModule.swiftmodule -module-name SingleFileModule
// RUN: %target-build-swift -I%t %s -o %t/a.out %target-rpath(%t)
// RUN: %target-run %t/a.out
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
import SingleFileModule
import StdlibUnittest
import _Differentiation
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
AlwaysEmitIntoClientTests.test("registration") {
expectEqual(42, gradient(at: 0, of: f))
expectEqual(42, gradient(at: 1, of: f))
expectEqual(42, gradient(at: 2, of: f))
}
runAllTests()
// CHECK: @"16SingleFileModule1fyS2fFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s16SingleFileModule1fyS2fFTJfSpSr", ptr @"$s16SingleFileModule1fyS2fFTJrSpSr" }