mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Merge remote-tracking branch 'origin/main' into rebranch
This commit is contained in:
@@ -4387,6 +4387,9 @@ NOTE(derivative_attr_fix_access,none,
|
||||
"mark the derivative function as "
|
||||
"'%select{private|fileprivate|internal|package|@usableFromInline|@usableFromInline}0' "
|
||||
"to match the original function", (AccessLevel))
|
||||
ERROR(derivative_attr_always_emit_into_client_mismatch,none,
|
||||
"either both or none of derivative and original function must have "
|
||||
"@alwaysEmitIntoClient attribute", ())
|
||||
ERROR(derivative_attr_static_method_mismatch_original,none,
|
||||
"unexpected derivative function declaration; "
|
||||
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",
|
||||
|
||||
@@ -159,8 +159,22 @@ void SILLinkerVisitor::maybeAddFunctionToWorklist(
|
||||
// HiddenExternal linkage when they are declarations, then they
|
||||
// become Shared after the body has been deserialized.
|
||||
// So try deserializing HiddenExternal functions too.
|
||||
if (linkage == SILLinkage::HiddenExternal)
|
||||
return deserializeAndPushToWorklist(F);
|
||||
if (linkage == SILLinkage::HiddenExternal) {
|
||||
deserializeAndPushToWorklist(F);
|
||||
if (!F->markedAsAlwaysEmitIntoClient())
|
||||
return;
|
||||
// For @_alwaysEmitIntoClient functions, we need to lookup its
|
||||
// differentiability witness and, if present, ask SILLoader to obtain its
|
||||
// definition. Otherwise, a linker error would occur due to undefined
|
||||
// reference to these symbols.
|
||||
for (SILDifferentiabilityWitness *witness :
|
||||
F->getModule().lookUpDifferentiabilityWitnessesForFunction(
|
||||
F->getName())) {
|
||||
F->getModule().getSILLoader()->lookupDifferentiabilityWitness(
|
||||
witness->getKey());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Update the linkage of the function in case it's different in the serialized
|
||||
// SIL than derived from the AST. This can be the case with cross-module-
|
||||
|
||||
@@ -1435,14 +1435,19 @@ void SILGenModule::emitDifferentiabilityWitness(
|
||||
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
|
||||
if (!diffWitness) {
|
||||
// Differentiability witnesses have the same linkage as the original
|
||||
// function, stripping external.
|
||||
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
|
||||
// function, stripping external. For @_alwaysEmitIntoClient original
|
||||
// functions, force PublicNonABI linkage of the differentiability witness so
|
||||
// we can serialize it (the original function itself might be HiddenExternal
|
||||
// in this case if we only have declaration without definition).
|
||||
auto linkage =
|
||||
originalFunction->markedAsAlwaysEmitIntoClient()
|
||||
? SILLinkage::PublicNonABI
|
||||
: stripExternalFromLinkage(originalFunction->getLinkage());
|
||||
diffWitness = SILDifferentiabilityWitness::createDefinition(
|
||||
M, linkage, originalFunction, diffKind, silConfig.parameterIndices,
|
||||
silConfig.resultIndices, config.derivativeGenericSignature,
|
||||
/*jvp*/ nullptr, /*vjp*/ nullptr,
|
||||
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
|
||||
attr);
|
||||
/*isSerialized*/ hasPublicVisibility(linkage), attr);
|
||||
}
|
||||
|
||||
// Set derivative function in differentiability witness.
|
||||
|
||||
@@ -6498,8 +6498,14 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
|
||||
auto loc = customDerivativeFn->getLocation();
|
||||
SILGenFunctionBuilder fb(*this);
|
||||
// Derivative thunks have the same linkage as the original function, stripping
|
||||
// external.
|
||||
auto linkage = stripExternalFromLinkage(originalFn->getLinkage());
|
||||
// external. For @_alwaysEmitIntoClient original functions, force PublicNonABI
|
||||
// linkage of derivative thunks so we can serialize them (the original
|
||||
// function itself might be HiddenExternal in this case if we only have
|
||||
// declaration without definition).
|
||||
auto linkage = originalFn->markedAsAlwaysEmitIntoClient()
|
||||
? SILLinkage::PublicNonABI
|
||||
: stripExternalFromLinkage(originalFn->getLinkage());
|
||||
|
||||
auto *thunk = fb.getOrCreateFunction(
|
||||
loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent,
|
||||
customDerivativeFn->getSerializedKind(),
|
||||
|
||||
@@ -538,9 +538,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
|
||||
"definitions with explicit differentiable attributes");
|
||||
|
||||
return SILDifferentiabilityWitness::createDeclaration(
|
||||
module, SILLinkage::PublicExternal, original, kind,
|
||||
minimalConfig->parameterIndices, minimalConfig->resultIndices,
|
||||
minimalConfig->derivativeGenericSignature);
|
||||
module,
|
||||
// Witness for @_alwaysEmitIntoClient original function must be emitted,
|
||||
// otherwise a linker error would occur due to undefined reference to the
|
||||
// witness symbol.
|
||||
original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI
|
||||
: SILLinkage::PublicExternal,
|
||||
original, kind, minimalConfig->parameterIndices,
|
||||
minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
|
||||
}
|
||||
|
||||
} // end namespace autodiff
|
||||
|
||||
@@ -999,10 +999,14 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
|
||||
|
||||
// We can generate empty JVP / VJP for functions available externally. These
|
||||
// functions have the same linkage as the original ones sans `external`
|
||||
// flag. Important exception here hidden_external functions as they are
|
||||
// serializable but corresponding hidden ones would be not and the SIL
|
||||
// verifier will fail. Patch `serializeFunctions` for this case.
|
||||
if (orig->getLinkage() == SILLinkage::HiddenExternal)
|
||||
// flag. Important exception here hidden_external non-@_alwaysEmitIntoClient
|
||||
// functions as they are serializable but corresponding hidden ones would be
|
||||
// not and the SIL verifier will fail. Patch `serializeFunctions` for this
|
||||
// case. For @_alwaysEmitIntoClient original functions (which might be
|
||||
// HiddenExternal if we only have declaration without definition), we want
|
||||
// derivatives to be serialized and do not patch `serializeFunctions`.
|
||||
if (orig->getLinkage() == SILLinkage::HiddenExternal &&
|
||||
!orig->markedAsAlwaysEmitIntoClient())
|
||||
serializeFunctions = IsNotSerialized;
|
||||
|
||||
// If the JVP doesn't exist, need to synthesize it.
|
||||
|
||||
@@ -6990,6 +6990,13 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (originalAFD->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>() !=
|
||||
derivative->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>()) {
|
||||
diags.diagnose(derivative->getLoc(),
|
||||
diag::derivative_attr_always_emit_into_client_mismatch);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Get the resolved differentiability parameter indices.
|
||||
auto *resolvedDiffParamIndices = attr->getParameterIndices();
|
||||
|
||||
|
||||
@@ -405,9 +405,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME(TF-1103): Derivative registration does not yet support
|
||||
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
|
||||
/*
|
||||
extension SIMD
|
||||
where
|
||||
Self: Differentiable,
|
||||
@@ -417,6 +414,7 @@ where
|
||||
TangentVector == Self
|
||||
{
|
||||
@inlinable
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: sum)
|
||||
func _vjpSum() -> (
|
||||
value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
|
||||
@@ -425,6 +423,7 @@ where
|
||||
}
|
||||
|
||||
@inlinable
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: sum)
|
||||
func _jvpSum() -> (
|
||||
value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
|
||||
@@ -432,7 +431,6 @@ where
|
||||
return (sum(), { v in Scalar.TangentVector(v.sum()) })
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
extension SIMD
|
||||
where
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify %s | %FileCheck %s
|
||||
/// Note: -primary-file prevents non_abi->shared linkage change in `removeSerializedFlagFromAllFunctions`
|
||||
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify -primary-file %s | %FileCheck %s
|
||||
|
||||
import _Differentiation
|
||||
|
||||
// CHECK: sil @test_nil_coalescing
|
||||
// CHECK: sil non_abi @test_nil_coalescing
|
||||
// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional<T>, %[[ARG_PB:.*]] :
|
||||
// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>):
|
||||
// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional<T>
|
||||
@@ -15,7 +16,7 @@ import _Differentiation
|
||||
//
|
||||
@_silgen_name("test_nil_coalescing")
|
||||
@derivative(of: ??)
|
||||
@usableFromInline
|
||||
@_alwaysEmitIntoClient
|
||||
func nilCoalescing<T: Differentiable>(optional: T?, defaultValue: @autoclosure () throws -> T)
|
||||
rethrows -> (value: T, pullback: (T.TangentVector) -> Optional<T>.TangentVector)
|
||||
{
|
||||
|
||||
@@ -1062,6 +1062,15 @@ func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullb
|
||||
fatalError()
|
||||
}
|
||||
|
||||
func internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: internal_original_alwaysemitintoclient_derivative_error)
|
||||
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
|
||||
func _internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
|
||||
fatalError()
|
||||
}
|
||||
|
||||
@_alwaysEmitIntoClient
|
||||
func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: internal_original_alwaysemitintoclient_derivative)
|
||||
@@ -1084,6 +1093,15 @@ package func _package_original_inlinable_derivative(_ x: Float) -> (value: Float
|
||||
fatalError()
|
||||
}
|
||||
|
||||
@_alwaysEmitIntoClient
|
||||
package func package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
|
||||
@derivative(of: package_original_alwaysemitintoclient_derivative_error)
|
||||
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
|
||||
package func _package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
|
||||
fatalError()
|
||||
}
|
||||
|
||||
@_alwaysEmitIntoClient
|
||||
package func package_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: package_original_alwaysemitintoclient_derivative)
|
||||
|
||||
@@ -19,9 +19,6 @@ SIMDTests.test("init(repeating:)") {
|
||||
expectEqual(8, pb1(g))
|
||||
}
|
||||
|
||||
// FIXME(TF-1103): Derivative registration does not yet support
|
||||
// `@_alwaysEmitIntoClient` original functions.
|
||||
/*
|
||||
SIMDTests.test("Sum") {
|
||||
let a = SIMD4<Float>(1, 2, 3, 4)
|
||||
|
||||
@@ -32,7 +29,6 @@ SIMDTests.test("Sum") {
|
||||
expectEqual(10, val1)
|
||||
expectEqual(SIMD4<Float>(3, 3, 3, 3), pb1(3))
|
||||
}
|
||||
*/
|
||||
|
||||
SIMDTests.test("Identity") {
|
||||
let a = SIMD4<Float>(1, 2, 3, 4)
|
||||
@@ -289,9 +285,6 @@ SIMDTests.test("Generics") {
|
||||
expectEqual(SIMD3<Double>(5, 10, 15), val4)
|
||||
expectEqual((SIMD3<Double>(5, 5, 5), 6), pb4(g))
|
||||
|
||||
// FIXME(TF-1103): Derivative registration does not yet support
|
||||
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
|
||||
/*
|
||||
func testSum<Scalar, SIMDType: SIMD>(x: SIMDType) -> Scalar
|
||||
where SIMDType.Scalar == Scalar,
|
||||
SIMDType : Differentiable,
|
||||
@@ -304,7 +297,6 @@ SIMDTests.test("Generics") {
|
||||
let (val5, pb5) = valueWithPullback(at: a, of: simd3Sum)
|
||||
expectEqual(6, val5)
|
||||
expectEqual(SIMD3<Double>(7, 7, 7), pb5(7))
|
||||
*/
|
||||
}
|
||||
|
||||
runAllTests()
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
@_alwaysEmitIntoClient
|
||||
public func f(_ x: Float) -> Float {
|
||||
x
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
import _Differentiation
|
||||
|
||||
@derivative(of: f)
|
||||
@_alwaysEmitIntoClient
|
||||
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
|
||||
(x, { 42 * $0 })
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
@_alwaysEmitIntoClient
|
||||
public func f(_ x: Float) -> Float {
|
||||
x
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
import MultiModule1
|
||||
import _Differentiation
|
||||
|
||||
@derivative(of: f)
|
||||
@_alwaysEmitIntoClient
|
||||
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
|
||||
(x, { 42 * $0 })
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
import _Differentiation
|
||||
|
||||
public protocol Protocol {
|
||||
var x : Float {get set}
|
||||
init()
|
||||
}
|
||||
|
||||
extension Protocol {
|
||||
public init(_ val: Float) {
|
||||
self.init()
|
||||
x = val
|
||||
}
|
||||
|
||||
@_alwaysEmitIntoClient
|
||||
public func sum() -> Float { x }
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
import MultiModuleProtocol1
|
||||
import _Differentiation
|
||||
|
||||
extension Protocol where Self: Differentiable, Self.TangentVector == Self {
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: sum)
|
||||
public func _vjpSum() -> (
|
||||
value: Float, pullback: (Float) -> Self.TangentVector
|
||||
) {
|
||||
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
|
||||
}
|
||||
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: sum)
|
||||
public func _jvpSum() -> (
|
||||
value: Float, differential: (Self.TangentVector) -> Float
|
||||
) {
|
||||
(value: self.x, differential: { 42 * $0.x })
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
import MultiModuleProtocol1
|
||||
import MultiModuleProtocol2
|
||||
import _Differentiation
|
||||
|
||||
public struct Struct : Protocol {
|
||||
private var _x : Float
|
||||
public var x : Float {
|
||||
get { _x }
|
||||
set { _x = newValue }
|
||||
}
|
||||
public init() { _x = 0 }
|
||||
}
|
||||
|
||||
extension Struct : AdditiveArithmetic {
|
||||
public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) }
|
||||
public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) }
|
||||
public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x }
|
||||
public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x }
|
||||
public static var zero: Self { Self(0) }
|
||||
}
|
||||
|
||||
extension Struct : Differentiable {
|
||||
public typealias TangentVector = Self
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
public struct Struct {
|
||||
public var x : Float
|
||||
public typealias TangentVector = Self
|
||||
public init() { x = 0 }
|
||||
}
|
||||
|
||||
extension Struct {
|
||||
public init(_ val: Float) {
|
||||
self.init()
|
||||
x = val
|
||||
}
|
||||
|
||||
@_alwaysEmitIntoClient
|
||||
public func sum() -> Float { x }
|
||||
}
|
||||
|
||||
extension Struct : AdditiveArithmetic {
|
||||
public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) }
|
||||
public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) }
|
||||
public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x }
|
||||
public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x }
|
||||
public static var zero: Self { Self(0) }
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
import MultiModuleStruct1
|
||||
import _Differentiation
|
||||
|
||||
extension Struct : Differentiable {
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: sum)
|
||||
public func _vjpSum() -> (
|
||||
value: Float, pullback: (Float) -> Self.TangentVector
|
||||
) {
|
||||
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
|
||||
}
|
||||
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: sum)
|
||||
public func _jvpSum() -> (
|
||||
value: Float, differential: (Self.TangentVector) -> Float
|
||||
) {
|
||||
(value: self.x, differential: { 42 * $0.x })
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
import MultiModuleStruct1
|
||||
import _Differentiation
|
||||
|
||||
extension Struct : Differentiable {
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: sum)
|
||||
public func _vjpSum() -> (
|
||||
value: Float, pullback: (Float) -> Self.TangentVector
|
||||
) {
|
||||
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
import MultiModuleStruct1
|
||||
import _Differentiation
|
||||
|
||||
extension Struct : Differentiable {
|
||||
@_alwaysEmitIntoClient
|
||||
@derivative(of: sum)
|
||||
public func _jvpSum() -> (
|
||||
value: Float, differential: (Self.TangentVector) -> Float
|
||||
) {
|
||||
(value: self.x, differential: { 42 * $0.x })
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
import _Differentiation
|
||||
|
||||
@_alwaysEmitIntoClient
|
||||
public func f(_ x: Float) -> Float {
|
||||
x
|
||||
}
|
||||
|
||||
@derivative(of: f)
|
||||
@_alwaysEmitIntoClient
|
||||
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
|
||||
(x, { 42 * $0 })
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
// REQUIRES: executable_test
|
||||
// RUN: %empty-directory(%t)
|
||||
|
||||
/// Note: we build just a module without a library since it would not contain any exported
|
||||
/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient.
|
||||
// RUN: %target-build-swift %S/Inputs/MultiFileModule/file1.swift %S/Inputs/MultiFileModule/file2.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiFileModule.swiftmodule -module-name MultiFileModule
|
||||
|
||||
// RUN: %target-build-swift -I%t %s -o %t/a.out %target-rpath(%t)
|
||||
// RUN: %target-run %t/a.out
|
||||
|
||||
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
|
||||
|
||||
import MultiFileModule
|
||||
import StdlibUnittest
|
||||
import _Differentiation
|
||||
|
||||
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
|
||||
|
||||
AlwaysEmitIntoClientTests.test("registration") {
|
||||
expectEqual(42, gradient(at: 0, of: f))
|
||||
expectEqual(42, gradient(at: 1, of: f))
|
||||
expectEqual(42, gradient(at: 2, of: f))
|
||||
}
|
||||
|
||||
runAllTests()
|
||||
|
||||
// CHECK: @"15MultiFileModule1fyS2fFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s15MultiFileModule1fyS2fFTJfSpSr", ptr @"$s15MultiFileModule1fyS2fFTJrSpSr" }
|
||||
@@ -0,0 +1,31 @@
|
||||
// REQUIRES: executable_test
|
||||
// RUN: %empty-directory(%t)
|
||||
|
||||
/// Note: we build just modules without libraries since they would not contain any exported
|
||||
/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient.
|
||||
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModule1)) %S/Inputs/MultiModule/file1.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiModule1.swiftmodule -module-name MultiModule1
|
||||
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModule2)) %S/Inputs/MultiModule/file2.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiModule2.swiftmodule -module-name MultiModule2 -I%t %target-rpath(%t)
|
||||
|
||||
// RUN: %target-build-swift -I%t %s -o %t/a.out %target-rpath(%t)
|
||||
// RUN: %target-run %t/a.out
|
||||
|
||||
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
|
||||
|
||||
import MultiModule1
|
||||
import MultiModule2
|
||||
import StdlibUnittest
|
||||
import _Differentiation
|
||||
|
||||
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
|
||||
|
||||
AlwaysEmitIntoClientTests.test("registration") {
|
||||
expectEqual(42, gradient(at: 0, of: f))
|
||||
expectEqual(42, gradient(at: 1, of: f))
|
||||
expectEqual(42, gradient(at: 2, of: f))
|
||||
}
|
||||
|
||||
runAllTests()
|
||||
|
||||
// CHECK: @"12MultiModule11fyS2fFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s12MultiModule11fyS2fFTJfSpSr", ptr @"$s12MultiModule11fyS2fFTJrSpSr" }
|
||||
@@ -0,0 +1,44 @@
|
||||
// REQUIRES: executable_test
|
||||
// RUN: %empty-directory(%t)
|
||||
|
||||
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleProtocol1)) %S/Inputs/MultiModuleProtocol/file1.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiModuleProtocol1.swiftmodule -module-name MultiModuleProtocol1
|
||||
|
||||
/// Note: we build just a module without a library since it would not contain any exported
|
||||
/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient.
|
||||
// RUN: %target-build-swift %S/Inputs/MultiModuleProtocol/file2.swift -emit-module -emit-module-path %t/MultiModuleProtocol2.swiftmodule \
|
||||
// RUN: -module-name MultiModuleProtocol2 -I%t -lMultiModuleProtocol1 %target-rpath(%t)
|
||||
|
||||
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleProtocol3)) %S/Inputs/MultiModuleProtocol/file3.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiModuleProtocol3.swiftmodule -module-name MultiModuleProtocol3 -I%t -L%t -lMultiModuleProtocol1 %target-rpath(%t)
|
||||
|
||||
/// Note: we enable forward-mode differentiation to automatically generate JVP for `foo`.
|
||||
/// It wraps `Protocol.sum` that has custom JVP defined in MultiModuleProtocol2, so we can test it.
|
||||
// RUN: %target-build-swift -Xfrontend -enable-experimental-forward-mode-differentiation \
|
||||
// RUN: -I%t -L%t %s -lMultiModuleProtocol1 -lMultiModuleProtocol3 -o %t/a.out %target-rpath(%t)
|
||||
// RUN: %target-run %t/a.out
|
||||
|
||||
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
|
||||
|
||||
import MultiModuleProtocol1
|
||||
import MultiModuleProtocol2
|
||||
import MultiModuleProtocol3
|
||||
import StdlibUnittest
|
||||
import _Differentiation
|
||||
|
||||
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
|
||||
|
||||
AlwaysEmitIntoClientTests.test("registration") {
|
||||
func foo<T: Protocol>(x: T) -> Float
|
||||
where T: Differentiable, T.TangentVector == T { x.sum() }
|
||||
expectEqual(Struct(42), pullback(at: Struct(0), of: foo)(1))
|
||||
expectEqual(Struct(42), pullback(at: Struct(1), of: foo)(1))
|
||||
expectEqual(Struct(42), pullback(at: Struct(2), of: foo)(1))
|
||||
expectEqual(42, differential(at: Struct(0), of: foo)(Struct(1)))
|
||||
expectEqual(42, differential(at: Struct(1), of: foo)(Struct(1)))
|
||||
expectEqual(42, differential(at: Struct(2), of: foo)(Struct(1)))
|
||||
}
|
||||
|
||||
runAllTests()
|
||||
|
||||
// CHECK: @"20MultiModuleProtocol18ProtocolPAAE3sumSfyFAaBRz16_Differentiation14DifferentiableRz13TangentVectorAeFPQzRszlWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s20MultiModuleProtocol18ProtocolPAAE3sumSfyFAaBRz16_Differentiation14DifferentiableRz13TangentVectorAeFPQzRszlTJfSpSr", ptr @"$s20MultiModuleProtocol18ProtocolPAAE3sumSfyFAaBRz16_Differentiation14DifferentiableRz13TangentVectorAeFPQzRszlTJrSpSr" }
|
||||
@@ -0,0 +1,36 @@
|
||||
// REQUIRES: executable_test
|
||||
// RUN: %empty-directory(%t)
|
||||
|
||||
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct1)) %S/Inputs/MultiModuleStruct/file1.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct1.swiftmodule -module-name MultiModuleStruct1
|
||||
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct2)) %S/Inputs/MultiModuleStruct/file2.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct2.swiftmodule -module-name MultiModuleStruct2 -I%t -L%t -lMultiModuleStruct1 %target-rpath(%t)
|
||||
|
||||
/// Note: we enable forward-mode differentiation to automatically generate JVP for `foo`.
|
||||
/// It wraps `Struct.sum` that has custom JVP defined in MultiModuleStruct2, so we can test it.
|
||||
// RUN: %target-build-swift -Xfrontend -enable-experimental-forward-mode-differentiation \
|
||||
// RUN: -I%t -L%t %s -lMultiModuleStruct1 -lMultiModuleStruct2 -o %t/a.out %target-rpath(%t)
|
||||
// RUN: %target-run %t/a.out
|
||||
|
||||
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
|
||||
|
||||
import MultiModuleStruct1
|
||||
import MultiModuleStruct2
|
||||
import StdlibUnittest
|
||||
import _Differentiation
|
||||
|
||||
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
|
||||
|
||||
AlwaysEmitIntoClientTests.test("registration") {
|
||||
func foo(x: Struct) -> Float { x.sum() }
|
||||
expectEqual(Struct(42), pullback(at: Struct(0), of: foo)(1))
|
||||
expectEqual(Struct(42), pullback(at: Struct(1), of: foo)(1))
|
||||
expectEqual(Struct(42), pullback(at: Struct(2), of: foo)(1))
|
||||
expectEqual(42, differential(at: Struct(0), of: foo)(Struct(1)))
|
||||
expectEqual(42, differential(at: Struct(1), of: foo)(Struct(1)))
|
||||
expectEqual(42, differential(at: Struct(2), of: foo)(Struct(1)))
|
||||
}
|
||||
|
||||
runAllTests()
|
||||
|
||||
// CHECK: @"18MultiModuleStruct16StructV3sumSfyFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJfSpSr", ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJrSpSr" }
|
||||
@@ -0,0 +1,37 @@
|
||||
// REQUIRES: executable_test
|
||||
// RUN: %empty-directory(%t)
|
||||
|
||||
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct1)) %S/Inputs/MultiModuleStruct/file1.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct1.swiftmodule -module-name MultiModuleStruct1
|
||||
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct2NoJVP)) %S/Inputs/MultiModuleStruct/file2_no_jvp.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct2NoJVP.swiftmodule -module-name MultiModuleStruct2NoJVP -I%t -L%t -lMultiModuleStruct1 %target-rpath(%t)
|
||||
|
||||
/// Note: we enable forward-mode differentiation to automatically generate JVP for `foo`.
|
||||
/// It wraps `Struct.sum` that has custom JVP defined in MultiModuleStruct2, so we can test it.
|
||||
// RUN: %target-build-swift -Xfrontend -enable-experimental-forward-mode-differentiation \
|
||||
// RUN: -I%t -L%t %s -lMultiModuleStruct1 -lMultiModuleStruct2NoJVP -o %t/a.out %target-rpath(%t)
|
||||
// RUN: %target-run %t/a.out
|
||||
|
||||
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
|
||||
|
||||
import MultiModuleStruct1
|
||||
import MultiModuleStruct2NoJVP
|
||||
import StdlibUnittest
|
||||
import _Differentiation
|
||||
|
||||
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
|
||||
|
||||
AlwaysEmitIntoClientTests.test("registration") {
|
||||
func foo(x: Struct) -> Float { x.sum() }
|
||||
expectEqual(Struct(42), pullback(at: Struct(0), of: foo)(1))
|
||||
expectEqual(Struct(42), pullback(at: Struct(1), of: foo)(1))
|
||||
expectEqual(Struct(42), pullback(at: Struct(2), of: foo)(1))
|
||||
/// Custom JVP for Struct.sum is not provided, a JVP causing fatal error is emitted.
|
||||
expectCrash{differential(at: Struct(0), of: foo)(Struct(1))}
|
||||
expectCrash{differential(at: Struct(1), of: foo)(Struct(1))}
|
||||
expectCrash{differential(at: Struct(2), of: foo)(Struct(1))}
|
||||
}
|
||||
|
||||
runAllTests()
|
||||
|
||||
// CHECK: @"18MultiModuleStruct16StructV3sumSfyFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJfSpSr", ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJrSpSr" }
|
||||
@@ -0,0 +1,26 @@
|
||||
// REQUIRES: executable_test
|
||||
// RUN: %empty-directory(%t)
|
||||
|
||||
// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct1)) %S/Inputs/MultiModuleStruct/file1.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct1.swiftmodule -module-name MultiModuleStruct1
|
||||
// RUN: not %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct2NoVJP)) %S/Inputs/MultiModuleStruct/file2_no_vjp.swift \
|
||||
// RUN: -emit-module -emit-module-path %t/MultiModuleStruct2NoVJP.swiftmodule -module-name MultiModuleStruct2NoVJP -I%t -L%t -lMultiModuleStruct1 %target-rpath(%t) 2>&1 | \
|
||||
// RUN: %FileCheck %s
|
||||
|
||||
// CHECK: file2_no_vjp.swift:6:4: error: function is not differentiable
|
||||
|
||||
import MultiModuleStruct1
|
||||
import MultiModuleStruct2
|
||||
import StdlibUnittest
|
||||
import _Differentiation
|
||||
|
||||
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
|
||||
|
||||
AlwaysEmitIntoClientTests.test("registration") {
|
||||
func foo(x: Struct) -> Float { x.sum() }
|
||||
expectEqual(42, differential(at: Struct(0), of: foo)(Struct(1)))
|
||||
expectEqual(42, differential(at: Struct(1), of: foo)(Struct(1)))
|
||||
expectEqual(42, differential(at: Struct(2), of: foo)(Struct(1)))
|
||||
}
|
||||
|
||||
runAllTests()
|
||||
@@ -0,0 +1,28 @@
|
||||
// REQUIRES: executable_test
|
||||
// RUN: %empty-directory(%t)
|
||||
|
||||
/// Note: we build just a module without a library since it would not contain any exported
|
||||
/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient.
|
||||
// RUN: %target-build-swift %S/Inputs/SingleFileModule/file.swift -emit-module \
|
||||
// RUN: -emit-module-path %t/SingleFileModule.swiftmodule -module-name SingleFileModule
|
||||
|
||||
// RUN: %target-build-swift -I%t %s -o %t/a.out %target-rpath(%t)
|
||||
// RUN: %target-run %t/a.out
|
||||
|
||||
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
|
||||
|
||||
import SingleFileModule
|
||||
import StdlibUnittest
|
||||
import _Differentiation
|
||||
|
||||
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
|
||||
|
||||
AlwaysEmitIntoClientTests.test("registration") {
|
||||
expectEqual(42, gradient(at: 0, of: f))
|
||||
expectEqual(42, gradient(at: 1, of: f))
|
||||
expectEqual(42, gradient(at: 2, of: f))
|
||||
}
|
||||
|
||||
runAllTests()
|
||||
|
||||
// CHECK: @"16SingleFileModule1fyS2fFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s16SingleFileModule1fyS2fFTJfSpSr", ptr @"$s16SingleFileModule1fyS2fFTJrSpSr" }
|
||||
Reference in New Issue
Block a user