mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Merge pull request #84800 from xedin/remove-csapply-operator-devirt
[CSApply] Don't attempt operator devirtualization
This commit is contained in:
@@ -43,6 +43,7 @@
|
|||||||
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
|
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
|
||||||
#include "swift/SILOptimizer/PassManager/Passes.h"
|
#include "swift/SILOptimizer/PassManager/Passes.h"
|
||||||
#include "swift/SILOptimizer/PassManager/Transforms.h"
|
#include "swift/SILOptimizer/PassManager/Transforms.h"
|
||||||
|
#include "swift/SILOptimizer/Utils/Devirtualize.h"
|
||||||
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
|
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
|
||||||
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
|
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
|
||||||
#include "llvm/ADT/APSInt.h"
|
#include "llvm/ADT/APSInt.h"
|
||||||
@@ -100,6 +101,24 @@ private:
|
|||||||
SILBuilder &builder, SILLocation loc,
|
SILBuilder &builder, SILLocation loc,
|
||||||
DifferentiationInvoker invoker);
|
DifferentiationInvoker invoker);
|
||||||
|
|
||||||
|
/// Creates and canonicalizes private differentiability witness for
|
||||||
|
/// `originalFn`, differentiating with respect to
|
||||||
|
/// `desiredParameterIndices`. Returns `nullptr` on failure signifying that a
|
||||||
|
/// diagnostic has been emitted using `invoker`.
|
||||||
|
SILDifferentiabilityWitness *createPrivateDifferentiabilityWitness(
|
||||||
|
CanSILFunctionType originalFnTy, SILFunction *originalFn,
|
||||||
|
IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
|
||||||
|
SILValue original, DifferentiationInvoker invoker);
|
||||||
|
|
||||||
|
/// Resolves differentiability witness for `originalFn`. Either by looking up
|
||||||
|
/// for registered one with respect to posible superset of `desiredIndices`,
|
||||||
|
/// or creating a new private one. Returns `nullptr` on failure signifying
|
||||||
|
/// that a diagnostic has been emitted using `invoker`.
|
||||||
|
SILDifferentiabilityWitness *getOrCreateMinimalDifferentiabilitywitness(
|
||||||
|
CanSILFunctionType originalFnTy, SILFunction *originalFn,
|
||||||
|
IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
|
||||||
|
SILValue original, DifferentiationInvoker invoker);
|
||||||
|
|
||||||
/// Emits a reference to a derivative function of `original`, differentiated
|
/// Emits a reference to a derivative function of `original`, differentiated
|
||||||
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
|
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
|
||||||
/// the derivative function and the actual indices that the derivative
|
/// the derivative function and the actual indices that the derivative
|
||||||
@@ -527,12 +546,142 @@ SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction(
|
|||||||
return silFnIt->second;
|
return silFnIt->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SILDifferentiabilityWitness *
|
||||||
|
DifferentiationTransformer::createPrivateDifferentiabilityWitness(
|
||||||
|
CanSILFunctionType originalFnTy, SILFunction *originalFn,
|
||||||
|
IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
|
||||||
|
SILValue original, DifferentiationInvoker invoker) {
|
||||||
|
|
||||||
|
// Check non-differentiable cases before creating a new private
|
||||||
|
// differentiability witness.
|
||||||
|
|
||||||
|
// If the function is intentionally marked as being opaque to
|
||||||
|
// differentiation, then we should not create a task for it.
|
||||||
|
if (originalFn->hasSemanticsAttr("autodiff.opaque")) {
|
||||||
|
context.emitNondifferentiabilityError(
|
||||||
|
original, invoker, diag::autodiff_opaque_function_not_differentiable);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check and diagnose non-differentiable arguments.
|
||||||
|
for (auto [paramIndex, param] :
|
||||||
|
llvm::enumerate(originalFnTy->getParameters())) {
|
||||||
|
if (!desiredParameterIndices->contains(paramIndex))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
SILType paramType = param.getSILStorageInterfaceType();
|
||||||
|
if (!paramType.isDifferentiable(context.getModule())) {
|
||||||
|
context.emitNondifferentiabilityError(
|
||||||
|
original, invoker, diag::autodiff_nondifferentiable_argument);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check and diagnose non-differentiable results.
|
||||||
|
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
|
||||||
|
unsigned firstYieldResultIndex =
|
||||||
|
firstSemanticParamResultIdx +
|
||||||
|
originalFnTy->getNumAutoDiffSemanticResultsParameters();
|
||||||
|
for (auto resultIndex : desiredResultIndices->getIndices()) {
|
||||||
|
SILType resultType;
|
||||||
|
if (resultIndex >= firstYieldResultIndex) {
|
||||||
|
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
|
||||||
|
auto yield = originalFnTy->getYields()[yieldResultIndex];
|
||||||
|
// We can only differentiate indirect yields. This should be diagnosed
|
||||||
|
// earlier in VJPCloner.
|
||||||
|
assert(yield.isAutoDiffSemanticResult() && "unsupported result");
|
||||||
|
resultType = yield.getSILStorageInterfaceType();
|
||||||
|
} else if (resultIndex >= firstSemanticParamResultIdx) {
|
||||||
|
auto semanticResultParamIdx = resultIndex - firstSemanticParamResultIdx;
|
||||||
|
auto semanticResultParam = *std::next(
|
||||||
|
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
|
||||||
|
semanticResultParamIdx);
|
||||||
|
resultType = semanticResultParam.getSILStorageInterfaceType();
|
||||||
|
} else {
|
||||||
|
resultType =
|
||||||
|
originalFnTy->getResults()[resultIndex].getSILStorageInterfaceType();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!resultType || !resultType.isDifferentiable(context.getModule())) {
|
||||||
|
context.emitNondifferentiabilityError(
|
||||||
|
original, invoker, diag::autodiff_nondifferentiable_result);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check and diagnose external declarations.
|
||||||
|
if (originalFn->isExternalDeclaration()) {
|
||||||
|
context.emitNondifferentiabilityError(
|
||||||
|
original, invoker, diag::autodiff_external_nondifferentiable_function);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Soundness check passed. Create a new differentiability witness
|
||||||
|
GenericSignature contextualDerivativeGenSig = GenericSignature();
|
||||||
|
if (invoker.getKind() ==
|
||||||
|
DifferentiationInvoker::Kind::IndirectDifferentiation)
|
||||||
|
contextualDerivativeGenSig = invoker.getIndirectDifferentiation()
|
||||||
|
.second->getDerivativeGenericSignature();
|
||||||
|
auto derivativeConstrainedGenSig =
|
||||||
|
autodiff::getConstrainedDerivativeGenericSignature(
|
||||||
|
originalFnTy, desiredParameterIndices, desiredResultIndices,
|
||||||
|
contextualDerivativeGenSig, LookUpConformanceInModule());
|
||||||
|
|
||||||
|
auto *witness = SILDifferentiabilityWitness::createDefinition(
|
||||||
|
context.getModule(), SILLinkage::Private, originalFn,
|
||||||
|
DifferentiabilityKind::Reverse, desiredParameterIndices,
|
||||||
|
desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
|
||||||
|
/*vjp*/ nullptr, /*isSerialized*/ false);
|
||||||
|
if (canonicalizeDifferentiabilityWitness(witness, invoker, IsNotSerialized))
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
return witness;
|
||||||
|
}
|
||||||
|
|
||||||
|
SILDifferentiabilityWitness *
|
||||||
|
DifferentiationTransformer::getOrCreateMinimalDifferentiabilitywitness(
|
||||||
|
CanSILFunctionType originalFnTy, SILFunction *originalFn,
|
||||||
|
IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
|
||||||
|
SILValue original, DifferentiationInvoker invoker) {
|
||||||
|
// NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has
|
||||||
|
// parameters corresponding to captured variables.
|
||||||
|
// TODO: If possible, change `autodiff::getLoweredParameterIndices` to
|
||||||
|
// take `CaptureInfo` into account.
|
||||||
|
if (originalFnTy->getNumParameters() >
|
||||||
|
desiredParameterIndices->getCapacity()) {
|
||||||
|
desiredParameterIndices = desiredParameterIndices->extendingCapacity(
|
||||||
|
context.getASTContext(), originalFnTy->getNumParameters());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up a differentiability witness with the exact configuration.
|
||||||
|
auto *minimalWitness = getExactDifferentiabilityWitness(
|
||||||
|
context.getModule(), originalFn, desiredParameterIndices,
|
||||||
|
desiredResultIndices);
|
||||||
|
|
||||||
|
// Otherwise, look up a differentiability witness with a minimal superset
|
||||||
|
// configuration.
|
||||||
|
if (!minimalWitness)
|
||||||
|
minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness(
|
||||||
|
context.getModule(), originalFn, DifferentiabilityKind::Reverse,
|
||||||
|
desiredParameterIndices, desiredResultIndices);
|
||||||
|
|
||||||
|
// Finally, try to create private differentiability witness
|
||||||
|
if (!minimalWitness)
|
||||||
|
minimalWitness = createPrivateDifferentiabilityWitness(
|
||||||
|
originalFnTy, originalFn, desiredParameterIndices, desiredResultIndices,
|
||||||
|
original, invoker);
|
||||||
|
|
||||||
|
return minimalWitness;
|
||||||
|
}
|
||||||
|
|
||||||
std::optional<std::pair<SILValue, AutoDiffConfig>>
|
std::optional<std::pair<SILValue, AutoDiffConfig>>
|
||||||
DifferentiationTransformer::emitDerivativeFunctionReference(
|
DifferentiationTransformer::emitDerivativeFunctionReference(
|
||||||
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
|
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
|
||||||
AutoDiffDerivativeFunctionKind kind, SILValue original,
|
AutoDiffDerivativeFunctionKind kind, SILValue original,
|
||||||
DifferentiationInvoker invoker,
|
DifferentiationInvoker invoker,
|
||||||
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
|
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
|
||||||
|
SILFunction *parentFn = original->getFunction();
|
||||||
|
|
||||||
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
|
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
|
||||||
// matches the given kind and desired differentiation parameter indices,
|
// matches the given kind and desired differentiation parameter indices,
|
||||||
// simply extract the derivative function of its function operand, retain the
|
// simply extract the derivative function of its function operand, retain the
|
||||||
@@ -576,111 +725,16 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
|
|||||||
auto loc = originalFRI->getLoc();
|
auto loc = originalFRI->getLoc();
|
||||||
auto *originalFn = originalFRI->getReferencedFunction();
|
auto *originalFn = originalFRI->getReferencedFunction();
|
||||||
auto originalFnTy = originalFn->getLoweredFunctionType();
|
auto originalFnTy = originalFn->getLoweredFunctionType();
|
||||||
auto *desiredParameterIndices = desiredConfig.parameterIndices;
|
|
||||||
auto *desiredResultIndices = desiredConfig.resultIndices;
|
auto *minimalWitness = getOrCreateMinimalDifferentiabilitywitness(
|
||||||
// NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has
|
originalFnTy, originalFn, desiredConfig.parameterIndices,
|
||||||
// parameters corresponding to captured variables.
|
desiredConfig.resultIndices, original, invoker);
|
||||||
// TODO: If possible, change `autodiff::getLoweredParameterIndices` to
|
|
||||||
// take `CaptureInfo` into account.
|
// All non-differentiability cases should be diagnosted
|
||||||
if (originalFnTy->getNumParameters() >
|
|
||||||
desiredParameterIndices->getCapacity()) {
|
|
||||||
desiredParameterIndices = desiredParameterIndices->extendingCapacity(
|
|
||||||
context.getASTContext(), originalFnTy->getNumParameters());
|
|
||||||
}
|
|
||||||
// Look up a differentiability witness with the exact configuration.
|
|
||||||
auto *minimalWitness = getExactDifferentiabilityWitness(
|
|
||||||
context.getModule(), originalFn, desiredParameterIndices,
|
|
||||||
desiredResultIndices);
|
|
||||||
// Otherwise, look up a differentiability witness with a minimal superset
|
|
||||||
// configuration.
|
|
||||||
if (!minimalWitness)
|
if (!minimalWitness)
|
||||||
minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness(
|
return std::nullopt;
|
||||||
context.getModule(), originalFn, DifferentiabilityKind::Reverse,
|
|
||||||
desiredParameterIndices, desiredResultIndices);
|
if (parentFn->isSerialized()) {
|
||||||
// If no minimal witness exists, check non-differentiable cases before
|
|
||||||
// creating a new private differentiability witness.
|
|
||||||
if (!minimalWitness) {
|
|
||||||
// If the function is intentionally marked as being opaque to
|
|
||||||
// differentiation, then we should not create a task for it.
|
|
||||||
if (originalFn->hasSemanticsAttr("autodiff.opaque")) {
|
|
||||||
context.emitNondifferentiabilityError(
|
|
||||||
original, invoker,
|
|
||||||
diag::autodiff_opaque_function_not_differentiable);
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
// Check and diagnose non-differentiable arguments.
|
|
||||||
auto originalFnTy = originalFn->getLoweredFunctionType();
|
|
||||||
for (unsigned paramIndex : range(originalFnTy->getNumParameters())) {
|
|
||||||
if (desiredConfig.isWrtParameter(paramIndex) &&
|
|
||||||
!originalFnTy->getParameters()[paramIndex]
|
|
||||||
.getSILStorageInterfaceType()
|
|
||||||
.isDifferentiable(context.getModule())) {
|
|
||||||
auto diag = context.emitNondifferentiabilityError(
|
|
||||||
original, invoker, diag::autodiff_nondifferentiable_argument);
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Check and diagnose non-differentiable results.
|
|
||||||
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
|
|
||||||
unsigned firstYieldResultIndex = originalFnTy->getNumResults() +
|
|
||||||
originalFnTy->getNumAutoDiffSemanticResultsParameters();
|
|
||||||
for (auto resultIndex : desiredResultIndices->getIndices()) {
|
|
||||||
SILType resultType;
|
|
||||||
if (resultIndex >= firstYieldResultIndex) {
|
|
||||||
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
|
|
||||||
auto yield = originalFnTy->getYields()[yieldResultIndex];
|
|
||||||
// We can only differentiate indirect yields. This should be diagnosed
|
|
||||||
// earlier in VJPCloner.
|
|
||||||
assert(yield.isAutoDiffSemanticResult() && "unsupported result");
|
|
||||||
resultType = yield.getSILStorageInterfaceType();
|
|
||||||
} else if (resultIndex >= firstSemanticParamResultIdx) {
|
|
||||||
auto semanticResultParamIdx = resultIndex - firstSemanticParamResultIdx;
|
|
||||||
auto semanticResultParam =
|
|
||||||
*std::next(originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
|
|
||||||
semanticResultParamIdx);
|
|
||||||
resultType = semanticResultParam.getSILStorageInterfaceType();
|
|
||||||
} else {
|
|
||||||
resultType = originalFnTy->getResults()[resultIndex]
|
|
||||||
.getSILStorageInterfaceType();
|
|
||||||
}
|
|
||||||
if (!resultType || !resultType.isDifferentiable(context.getModule())) {
|
|
||||||
context.emitNondifferentiabilityError(
|
|
||||||
original, invoker, diag::autodiff_nondifferentiable_result);
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Check and diagnose external declarations.
|
|
||||||
if (originalFn->isExternalDeclaration()) {
|
|
||||||
context.emitNondifferentiabilityError(
|
|
||||||
original, invoker,
|
|
||||||
diag::autodiff_external_nondifferentiable_function);
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
// Soundness check passed. Create a new differentiability witness and
|
|
||||||
// canonicalize it.
|
|
||||||
GenericSignature contextualDerivativeGenSig = GenericSignature();
|
|
||||||
if (invoker.getKind() ==
|
|
||||||
DifferentiationInvoker::Kind::IndirectDifferentiation)
|
|
||||||
contextualDerivativeGenSig =
|
|
||||||
invoker.getIndirectDifferentiation()
|
|
||||||
.second->getDerivativeGenericSignature();
|
|
||||||
auto derivativeConstrainedGenSig =
|
|
||||||
autodiff::getConstrainedDerivativeGenericSignature(
|
|
||||||
originalFn->getLoweredFunctionType(),
|
|
||||||
desiredParameterIndices, desiredResultIndices,
|
|
||||||
contextualDerivativeGenSig,
|
|
||||||
LookUpConformanceInModule());
|
|
||||||
minimalWitness = SILDifferentiabilityWitness::createDefinition(
|
|
||||||
context.getModule(), SILLinkage::Private, originalFn,
|
|
||||||
DifferentiabilityKind::Reverse, desiredParameterIndices,
|
|
||||||
desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
|
|
||||||
/*vjp*/ nullptr, /*isSerialized*/ false);
|
|
||||||
if (canonicalizeDifferentiabilityWitness(minimalWitness, invoker,
|
|
||||||
IsNotSerialized))
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
assert(minimalWitness);
|
|
||||||
if (original->getFunction()->isSerialized()) {
|
|
||||||
bool isWitnessPublic;
|
bool isWitnessPublic;
|
||||||
if (SILFunction *unwrappedFn =
|
if (SILFunction *unwrappedFn =
|
||||||
getUnwrappedCurryThunkFunction(originalFn)) {
|
getUnwrappedCurryThunkFunction(originalFn)) {
|
||||||
@@ -691,15 +745,12 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
|
|||||||
// private linkage.
|
// private linkage.
|
||||||
isWitnessPublic = hasPublicVisibility(unwrappedFn->getLinkage());
|
isWitnessPublic = hasPublicVisibility(unwrappedFn->getLinkage());
|
||||||
} else if (originalFn->getDeclRef().getAbstractClosureExpr() &&
|
} else if (originalFn->getDeclRef().getAbstractClosureExpr() &&
|
||||||
originalFRI->getFunction()
|
parentFn->getDeclRef().isDefaultArgGenerator()) {
|
||||||
->getDeclRef()
|
|
||||||
.isDefaultArgGenerator()) {
|
|
||||||
// If we reference a closure from inside default argument generator,
|
// If we reference a closure from inside default argument generator,
|
||||||
// check against generator's visibility. If the function having this
|
// check against generator's visibility. If the function having this
|
||||||
// default argument has public visibility, it's OK to have a closure
|
// default argument has public visibility, it's OK to have a closure
|
||||||
// (which always has private visibility) as its default value.
|
// (which always has private visibility) as its default value.
|
||||||
isWitnessPublic =
|
isWitnessPublic = hasPublicVisibility(parentFn->getLinkage());
|
||||||
hasPublicVisibility(originalFRI->getFunction()->getLinkage());
|
|
||||||
} else {
|
} else {
|
||||||
isWitnessPublic = hasPublicVisibility(minimalWitness->getLinkage());
|
isWitnessPublic = hasPublicVisibility(minimalWitness->getLinkage());
|
||||||
}
|
}
|
||||||
@@ -709,16 +760,16 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
|
|||||||
// FIXME: This is not a very robust way of determining if the function
|
// FIXME: This is not a very robust way of determining if the function
|
||||||
// is a default argument. Also, we have not exhaustively listed all the
|
// is a default argument. Also, we have not exhaustively listed all the
|
||||||
// kinds of fragility.
|
// kinds of fragility.
|
||||||
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
|
if (parentFn->getLinkage() == SILLinkage::PublicNonABI)
|
||||||
fragileKind = DefaultArgument;
|
fragileKind = DefaultArgument;
|
||||||
context.emitNondifferentiabilityError(
|
context.emitNondifferentiabilityError(
|
||||||
original, invoker, diag::autodiff_private_derivative_from_fragile,
|
original, invoker, diag::autodiff_private_derivative_from_fragile,
|
||||||
fragileKind,
|
fragileKind,
|
||||||
isa_and_nonnull<AbstractClosureExpr>(
|
isa_and_nonnull<AbstractClosureExpr>(loc.getAsASTNode<Expr>()));
|
||||||
originalFRI->getLoc().getAsASTNode<Expr>()));
|
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(TF-482): Move generic requirement checking logic to
|
// TODO(TF-482): Move generic requirement checking logic to
|
||||||
// `getExactDifferentiabilityWitness` and
|
// `getExactDifferentiabilityWitness` and
|
||||||
// `getOrCreateMinimalASTDifferentiabilityWitness`.
|
// `getOrCreateMinimalASTDifferentiabilityWitness`.
|
||||||
@@ -726,18 +777,17 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
|
|||||||
// By default, use the forwarding substitution map of the original function.
|
// By default, use the forwarding substitution map of the original function.
|
||||||
// If the original callee is a `partial_apply` or `apply` instruction, use
|
// If the original callee is a `partial_apply` or `apply` instruction, use
|
||||||
// its substitution map instead.
|
// its substitution map instead.
|
||||||
auto substMap = original->getFunction()->getForwardingSubstitutionMap();
|
auto substMap = parentFn->getForwardingSubstitutionMap();
|
||||||
if (auto *pai =
|
if (auto *pai = peerThroughFunctionConversions<PartialApplyInst>(original))
|
||||||
peerThroughFunctionConversions<PartialApplyInst>(original)) {
|
|
||||||
substMap = pai->getSubstitutionMap();
|
substMap = pai->getSubstitutionMap();
|
||||||
} else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original)) {
|
else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original))
|
||||||
substMap = ai->getSubstitutionMap();
|
substMap = ai->getSubstitutionMap();
|
||||||
}
|
|
||||||
if (diagnoseUnsatisfiedRequirements(
|
if (diagnoseUnsatisfiedRequirements(
|
||||||
context, original->getType().castTo<SILFunctionType>(),
|
context, original->getType().castTo<SILFunctionType>(),
|
||||||
minimalWitness->getDerivativeGenericSignature(), substMap, invoker,
|
minimalWitness->getDerivativeGenericSignature(), substMap, invoker,
|
||||||
original.getLoc().getSourceLoc()))
|
original.getLoc().getSourceLoc()))
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
DifferentiabilityWitnessFunctionKind witnessKind;
|
DifferentiabilityWitnessFunctionKind witnessKind;
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
case AutoDiffDerivativeFunctionKind::JVP:
|
case AutoDiffDerivativeFunctionKind::JVP:
|
||||||
@@ -764,6 +814,79 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
|
|||||||
if (auto *witnessMethod =
|
if (auto *witnessMethod =
|
||||||
peerThroughFunctionConversions<WitnessMethodInst>(original)) {
|
peerThroughFunctionConversions<WitnessMethodInst>(original)) {
|
||||||
auto loc = witnessMethod->getLoc();
|
auto loc = witnessMethod->getLoc();
|
||||||
|
// See if we can derive derivatives for a method statically
|
||||||
|
auto [method, table] = lookUpFunctionInWitnessTable(
|
||||||
|
witnessMethod, SILModule::LinkingMode::LinkNormal);
|
||||||
|
if (method) {
|
||||||
|
auto *originalFn = method;
|
||||||
|
auto originalFnTy = method->getLoweredFunctionTypeInContext(
|
||||||
|
TypeExpansionContext(*original->getFunction()));
|
||||||
|
|
||||||
|
auto *minimalWitness = getOrCreateMinimalDifferentiabilitywitness(
|
||||||
|
originalFnTy, originalFn, desiredConfig.parameterIndices,
|
||||||
|
desiredConfig.resultIndices, original, invoker);
|
||||||
|
|
||||||
|
// All non-differentiability cases should be diagnosted
|
||||||
|
if (!minimalWitness)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
if (parentFn->isSerialized() &&
|
||||||
|
!hasPublicVisibility(minimalWitness->getLinkage())) {
|
||||||
|
enum { Inlinable = 0, DefaultArgument = 1 };
|
||||||
|
unsigned fragileKind = Inlinable;
|
||||||
|
// FIXME: This is not a very robust way of determining if the function
|
||||||
|
// is a default argument. Also, we have not exhaustively listed all the
|
||||||
|
// kinds of fragility.
|
||||||
|
if (parentFn->getLinkage() == SILLinkage::PublicNonABI)
|
||||||
|
fragileKind = DefaultArgument;
|
||||||
|
context.emitNondifferentiabilityError(
|
||||||
|
original, invoker, diag::autodiff_private_derivative_from_fragile,
|
||||||
|
fragileKind,
|
||||||
|
isa_and_nonnull<AbstractClosureExpr>(loc.getAsASTNode<Expr>()));
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto substMap = parentFn->getForwardingSubstitutionMap();
|
||||||
|
if (auto *pai =
|
||||||
|
peerThroughFunctionConversions<PartialApplyInst>(original))
|
||||||
|
substMap =
|
||||||
|
getWitnessMethodSubstitutions(context.getModule(), pai, originalFn,
|
||||||
|
witnessMethod->getConformance());
|
||||||
|
else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original))
|
||||||
|
substMap =
|
||||||
|
getWitnessMethodSubstitutions(context.getModule(), ai, originalFn,
|
||||||
|
witnessMethod->getConformance());
|
||||||
|
if (diagnoseUnsatisfiedRequirements(
|
||||||
|
context, original->getType().castTo<SILFunctionType>(),
|
||||||
|
minimalWitness->getDerivativeGenericSignature(), substMap,
|
||||||
|
invoker, original.getLoc().getSourceLoc()))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
DifferentiabilityWitnessFunctionKind witnessKind;
|
||||||
|
switch (kind) {
|
||||||
|
case AutoDiffDerivativeFunctionKind::JVP:
|
||||||
|
witnessKind = DifferentiabilityWitnessFunctionKind::JVP;
|
||||||
|
break;
|
||||||
|
case AutoDiffDerivativeFunctionKind::VJP:
|
||||||
|
witnessKind = DifferentiabilityWitnessFunctionKind::VJP;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto *derivativeFnRef = builder.createDifferentiabilityWitnessFunction(
|
||||||
|
loc, witnessKind, minimalWitness);
|
||||||
|
|
||||||
|
auto convertedRef = reapplyFunctionConversion(
|
||||||
|
context, derivativeFnRef, witnessMethod, original, builder, loc,
|
||||||
|
newBuffersToDealloc, desiredConfig.parameterIndices,
|
||||||
|
desiredConfig.resultIndices,
|
||||||
|
derivativeFnRef->getType()
|
||||||
|
.getASTType()
|
||||||
|
->castTo<SILFunctionType>()
|
||||||
|
->getSubstGenericSignature());
|
||||||
|
|
||||||
|
return std::make_pair(convertedRef, minimalWitness->getConfig());
|
||||||
|
}
|
||||||
|
|
||||||
auto requirementDeclRef = witnessMethod->getMember();
|
auto requirementDeclRef = witnessMethod->getMember();
|
||||||
auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl();
|
auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl();
|
||||||
// If requirement declaration does not have any derivative function
|
// If requirement declaration does not have any derivative function
|
||||||
|
|||||||
@@ -588,66 +588,6 @@ namespace {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns None if the AST does not contain enough information to recover
|
|
||||||
// substitutions; this is different from an Optional(SubstitutionMap()),
|
|
||||||
// indicating a valid call to a non-generic operator.
|
|
||||||
std::optional<SubstitutionMap> getOperatorSubstitutions(ValueDecl *witness,
|
|
||||||
Type refType) {
|
|
||||||
// We have to recover substitutions in this hacky way because
|
|
||||||
// the AST does not retain enough information to devirtualize
|
|
||||||
// calls like this.
|
|
||||||
auto witnessType = witness->getInterfaceType();
|
|
||||||
|
|
||||||
// Compute the substitutions.
|
|
||||||
auto *gft = witnessType->getAs<GenericFunctionType>();
|
|
||||||
if (gft == nullptr) {
|
|
||||||
if (refType->isEqual(witnessType))
|
|
||||||
return SubstitutionMap();
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto sig = gft->getGenericSignature();
|
|
||||||
auto *env = sig.getGenericEnvironment();
|
|
||||||
|
|
||||||
witnessType = FunctionType::get(gft->getParams(),
|
|
||||||
gft->getResult(),
|
|
||||||
gft->getExtInfo());
|
|
||||||
witnessType = env->mapTypeIntoContext(witnessType);
|
|
||||||
|
|
||||||
TypeSubstitutionMap subs;
|
|
||||||
auto substType = witnessType->substituteBindingsTo(
|
|
||||||
refType,
|
|
||||||
[&](ArchetypeType *origType, CanType substType) -> CanType {
|
|
||||||
if (auto gpType = dyn_cast<GenericTypeParamType>(
|
|
||||||
origType->getInterfaceType()->getCanonicalType()))
|
|
||||||
subs[gpType] = substType;
|
|
||||||
|
|
||||||
return substType;
|
|
||||||
});
|
|
||||||
|
|
||||||
// If substitution failed, it means that the protocol requirement type
|
|
||||||
// and the witness type did not match up. The only time that this
|
|
||||||
// should happen is when the witness is defined in a base class and
|
|
||||||
// the actual call uses a derived class. For example,
|
|
||||||
//
|
|
||||||
// protocol P { func +(lhs: Self, rhs: Self) }
|
|
||||||
// class Base : P { func +(lhs: Base, rhs: Base) {} }
|
|
||||||
// class Derived : Base {}
|
|
||||||
//
|
|
||||||
// If we enter this code path with two operands of type Derived,
|
|
||||||
// we know we're calling the protocol requirement P.+, with a
|
|
||||||
// substituted type of (Derived, Derived) -> (). But the type of
|
|
||||||
// the witness is (Base, Base) -> (). Just bail out and make a
|
|
||||||
// witness method call in this rare case; SIL mandatory optimizations
|
|
||||||
// will likely devirtualize it anyway.
|
|
||||||
if (!substType)
|
|
||||||
return std::nullopt;
|
|
||||||
|
|
||||||
return SubstitutionMap::get(sig,
|
|
||||||
QueryTypeSubstitutionMap{subs},
|
|
||||||
LookUpConformanceInModule());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Determine whether the given reference is to a method on
|
/// Determine whether the given reference is to a method on
|
||||||
/// a remote distributed actor in the given context.
|
/// a remote distributed actor in the given context.
|
||||||
bool isDistributedThunk(ConcreteDeclRef ref, Expr *context);
|
bool isDistributedThunk(ConcreteDeclRef ref, Expr *context);
|
||||||
@@ -674,65 +614,6 @@ namespace {
|
|||||||
|
|
||||||
auto baseTy = getBaseType(adjustedFullType->castTo<FunctionType>());
|
auto baseTy = getBaseType(adjustedFullType->castTo<FunctionType>());
|
||||||
|
|
||||||
// Handle operator requirements found in protocols.
|
|
||||||
if (auto proto = dyn_cast<ProtocolDecl>(decl->getDeclContext())) {
|
|
||||||
bool isCurried = shouldBuildCurryThunk(choice, /*baseIsInstance=*/false);
|
|
||||||
|
|
||||||
// If we have a concrete conformance, build a call to the witness.
|
|
||||||
//
|
|
||||||
// FIXME: This is awful. We should be able to handle this as a call to
|
|
||||||
// the protocol requirement with Self == the concrete type, and SILGen
|
|
||||||
// (or later) can devirtualize as appropriate.
|
|
||||||
auto conformance = checkConformance(baseTy, proto);
|
|
||||||
if (conformance.isConcrete()) {
|
|
||||||
if (auto witness = conformance.getConcrete()->getWitnessDecl(decl)) {
|
|
||||||
bool isMemberOperator = witness->getDeclContext()->isTypeContext();
|
|
||||||
|
|
||||||
if (!isMemberOperator || !isCurried) {
|
|
||||||
// The fullType was computed by substituting the protocol
|
|
||||||
// requirement so it always has a (Self) -> ... curried
|
|
||||||
// application. Strip it off if the witness was a top-level
|
|
||||||
// function.
|
|
||||||
Type refType;
|
|
||||||
if (isMemberOperator)
|
|
||||||
refType = adjustedFullType;
|
|
||||||
else
|
|
||||||
refType = adjustedFullType->castTo<AnyFunctionType>()->getResult();
|
|
||||||
|
|
||||||
// Build the AST for the call to the witness.
|
|
||||||
auto subMap = getOperatorSubstitutions(witness, refType);
|
|
||||||
if (subMap) {
|
|
||||||
ConcreteDeclRef witnessRef(witness, *subMap);
|
|
||||||
auto declRefExpr = new (ctx) DeclRefExpr(witnessRef, loc,
|
|
||||||
/*Implicit=*/false);
|
|
||||||
declRefExpr->setFunctionRefInfo(choice.getFunctionRefInfo());
|
|
||||||
cs.setType(declRefExpr, refType);
|
|
||||||
|
|
||||||
Expr *refExpr;
|
|
||||||
if (isMemberOperator) {
|
|
||||||
// If the operator is a type member, add the implicit
|
|
||||||
// (Self) -> ... call.
|
|
||||||
Expr *base =
|
|
||||||
TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy,
|
|
||||||
ctx);
|
|
||||||
cs.setType(base, MetatypeType::get(baseTy));
|
|
||||||
|
|
||||||
refExpr =
|
|
||||||
DotSyntaxCallExpr::create(ctx, declRefExpr, SourceLoc(),
|
|
||||||
Argument::unlabeled(base));
|
|
||||||
auto refType = adjustedFullType->castTo<FunctionType>()->getResult();
|
|
||||||
cs.setType(refExpr, refType);
|
|
||||||
} else {
|
|
||||||
refExpr = declRefExpr;
|
|
||||||
}
|
|
||||||
|
|
||||||
return forceUnwrapIfExpected(refExpr, locator);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a reference to the member.
|
// Build a reference to the member.
|
||||||
Expr *base =
|
Expr *base =
|
||||||
TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy, ctx);
|
TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy, ctx);
|
||||||
|
|||||||
@@ -2,6 +2,9 @@
|
|||||||
// REQUIRES: executable_test
|
// REQUIRES: executable_test
|
||||||
|
|
||||||
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.
|
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.
|
||||||
|
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||||
|
We cannot use `SIMD` :( */
|
||||||
|
// XFAIL: *
|
||||||
|
|
||||||
import _Differentiation
|
import _Differentiation
|
||||||
import StdlibUnittest
|
import StdlibUnittest
|
||||||
|
|||||||
@@ -2,6 +2,9 @@
|
|||||||
// NOTE: Verify whether forward-mode differentiation crashes. It currently does.
|
// NOTE: Verify whether forward-mode differentiation crashes. It currently does.
|
||||||
// RUN: not --crash %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil %s
|
// RUN: not --crash %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil %s
|
||||||
// REQUIRES: executable_test
|
// REQUIRES: executable_test
|
||||||
|
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||||
|
We cannot use `Tracked<T>` :( */
|
||||||
|
// XFAIL: *
|
||||||
|
|
||||||
import StdlibUnittest
|
import StdlibUnittest
|
||||||
import DifferentiationUnittest
|
import DifferentiationUnittest
|
||||||
|
|||||||
@@ -9,6 +9,144 @@ import DifferentiationUnittest
|
|||||||
|
|
||||||
var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty")
|
var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty")
|
||||||
|
|
||||||
|
struct TangentSpace : AdditiveArithmetic {
|
||||||
|
let x, y: Float
|
||||||
|
}
|
||||||
|
|
||||||
|
extension TangentSpace : Differentiable {
|
||||||
|
typealias TangentVector = TangentSpace
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Space {
|
||||||
|
/// `x` is a computed property with a custom vjp.
|
||||||
|
var x: Float {
|
||||||
|
@differentiable(reverse)
|
||||||
|
get { storedX }
|
||||||
|
set { storedX = newValue }
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: x)
|
||||||
|
func vjpX() -> (value: Float, pullback: (Float) -> TangentSpace) {
|
||||||
|
return (x, { v in TangentSpace(x: v, y: 0) } )
|
||||||
|
}
|
||||||
|
|
||||||
|
private var storedX: Float
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
var y: Float
|
||||||
|
|
||||||
|
init(x: Float, y: Float) {
|
||||||
|
self.storedX = x
|
||||||
|
self.y = y
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension Space : Differentiable {
|
||||||
|
typealias TangentVector = TangentSpace
|
||||||
|
mutating func move(by offset: TangentSpace) {
|
||||||
|
x.move(by: offset.x)
|
||||||
|
y.move(by: offset.y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
E2EDifferentiablePropertyTests.test("computed property") {
|
||||||
|
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
|
||||||
|
return 2 * point.x
|
||||||
|
}
|
||||||
|
let expectedGrad = TangentSpace(x: 2, y: 0)
|
||||||
|
expectEqual(expectedGrad, actualGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
E2EDifferentiablePropertyTests.test("stored property") {
|
||||||
|
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
|
||||||
|
return 3 * point.y
|
||||||
|
}
|
||||||
|
let expectedGrad = TangentSpace(x: 0, y: 3)
|
||||||
|
expectEqual(expectedGrad, actualGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GenericMemberWrapper<T : Differentiable> : Differentiable {
|
||||||
|
// Stored property.
|
||||||
|
@differentiable(reverse)
|
||||||
|
var x: T
|
||||||
|
|
||||||
|
func vjpX() -> (T, (T.TangentVector) -> GenericMemberWrapper.TangentVector) {
|
||||||
|
return (x, { TangentVector(x: $0) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
E2EDifferentiablePropertyTests.test("generic stored property") {
|
||||||
|
let actualGrad = gradient(at: GenericMemberWrapper<Float>(x: 1)) { point in
|
||||||
|
return 2 * point.x
|
||||||
|
}
|
||||||
|
let expectedGrad = GenericMemberWrapper<Float>.TangentVector(x: 2)
|
||||||
|
expectEqual(expectedGrad, actualGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ProductSpaceSelfTangent : AdditiveArithmetic {
|
||||||
|
let x, y: Float
|
||||||
|
}
|
||||||
|
|
||||||
|
extension ProductSpaceSelfTangent : Differentiable {
|
||||||
|
typealias TangentVector = ProductSpaceSelfTangent
|
||||||
|
}
|
||||||
|
|
||||||
|
E2EDifferentiablePropertyTests.test("fieldwise product space, self tangent") {
|
||||||
|
let actualGrad = gradient(at: ProductSpaceSelfTangent(x: 0, y: 0)) { (point: ProductSpaceSelfTangent) -> Float in
|
||||||
|
return 5 * point.y
|
||||||
|
}
|
||||||
|
let expectedGrad = ProductSpaceSelfTangent(x: 0, y: 5)
|
||||||
|
expectEqual(expectedGrad, actualGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ProductSpaceOtherTangentTangentSpace : AdditiveArithmetic {
|
||||||
|
let x, y: Float
|
||||||
|
}
|
||||||
|
|
||||||
|
extension ProductSpaceOtherTangentTangentSpace : Differentiable {
|
||||||
|
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ProductSpaceOtherTangent {
|
||||||
|
var x, y: Float
|
||||||
|
}
|
||||||
|
|
||||||
|
extension ProductSpaceOtherTangent : Differentiable {
|
||||||
|
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
|
||||||
|
mutating func move(by offset: ProductSpaceOtherTangentTangentSpace) {
|
||||||
|
x.move(by: offset.x)
|
||||||
|
y.move(by: offset.y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
E2EDifferentiablePropertyTests.test("fieldwise product space, other tangent") {
|
||||||
|
let actualGrad = gradient(
|
||||||
|
at: ProductSpaceOtherTangent(x: 0, y: 0)
|
||||||
|
) { (point: ProductSpaceOtherTangent) -> Float in
|
||||||
|
return 7 * point.y
|
||||||
|
}
|
||||||
|
let expectedGrad = ProductSpaceOtherTangentTangentSpace(x: 0, y: 7)
|
||||||
|
expectEqual(expectedGrad, actualGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
E2EDifferentiablePropertyTests.test("computed property") {
|
||||||
|
struct TF_544 : Differentiable {
|
||||||
|
var value: Float
|
||||||
|
@differentiable(reverse)
|
||||||
|
var computed: Float {
|
||||||
|
get { value }
|
||||||
|
set { value = newValue }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let actualGrad = gradient(at: TF_544(value: 2.4)) { x in
|
||||||
|
return x.computed * x.computed
|
||||||
|
}
|
||||||
|
let expectedGrad = TF_544.TangentVector(value: 4.8)
|
||||||
|
expectEqual(expectedGrad, actualGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||||
|
We cannot use `Tracked<T>` :(
|
||||||
struct TangentSpace : AdditiveArithmetic {
|
struct TangentSpace : AdditiveArithmetic {
|
||||||
let x, y: Tracked<Float>
|
let x, y: Tracked<Float>
|
||||||
}
|
}
|
||||||
@@ -144,5 +282,6 @@ E2EDifferentiablePropertyTests.testWithLeakChecking("computed property") {
|
|||||||
let expectedGrad = TF_544.TangentVector(value: 4.8)
|
let expectedGrad = TF_544.TangentVector(value: 4.8)
|
||||||
expectEqual(expectedGrad, actualGrad)
|
expectEqual(expectedGrad, actualGrad)
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
runAllTests()
|
runAllTests()
|
||||||
|
|||||||
@@ -8,19 +8,39 @@ var ExistentialTests = TestSuite("Existential")
|
|||||||
|
|
||||||
protocol A {
|
protocol A {
|
||||||
@differentiable(reverse, wrt: x)
|
@differentiable(reverse, wrt: x)
|
||||||
func a(_ x: Tracked<Float>) -> Tracked<Float>
|
func a(_ x: Float) -> Float
|
||||||
}
|
}
|
||||||
func b(g: A) -> Tracked<Float> {
|
func b(g: A) -> Float {
|
||||||
return gradient(at: 3) { x in g.a(x) }
|
return gradient(at: 3) { x in g.a(x) }
|
||||||
}
|
}
|
||||||
|
|
||||||
struct B : A {
|
struct B : A {
|
||||||
@differentiable(reverse, wrt: x)
|
@differentiable(reverse, wrt: x)
|
||||||
func a(_ x: Tracked<Float>) -> Tracked<Float> { return x * 5 }
|
func a(_ x: Float) -> Float { return x * 5 }
|
||||||
}
|
}
|
||||||
|
|
||||||
ExistentialTests.testWithLeakChecking("Existential method VJP") {
|
ExistentialTests.test("Existential method VJP-Tracked") {
|
||||||
expectEqual(5.0, b(g: B()))
|
expectEqual(5.0, b(g: B()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||||
|
We cannot use `Tracked<T>` :(
|
||||||
|
protocol ATracked {
|
||||||
|
@differentiable(reverse, wrt: x)
|
||||||
|
func a(_ x: Tracked<Float>) -> Tracked<Float>
|
||||||
|
}
|
||||||
|
func b(g: ATracked) -> Tracked<Float> {
|
||||||
|
return gradient(at: 3) { x in g.a(x) }
|
||||||
|
}
|
||||||
|
|
||||||
|
struct BTracked : ATracked {
|
||||||
|
@differentiable(reverse, wrt: x)
|
||||||
|
func a(_ x: Tracked<Float>) -> Tracked<Float> { return x * 5 }
|
||||||
|
}
|
||||||
|
|
||||||
|
ExistentialTests.testWithLeakChecking("Existential method VJP-Tracked") {
|
||||||
|
expectEqual(5.0, b(g: BTracked()))
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
runAllTests()
|
runAllTests()
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
|
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
|
||||||
// REQUIRES: executable_test
|
// REQUIRES: executable_test
|
||||||
|
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||||
|
We cannot use `SIMD` :( */
|
||||||
|
// XFAIL: *
|
||||||
|
|
||||||
import StdlibUnittest
|
import StdlibUnittest
|
||||||
import DifferentiationUnittest
|
import DifferentiationUnittest
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
|
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
|
||||||
// REQUIRES: executable_test
|
// REQUIRES: executable_test
|
||||||
|
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||||
|
We cannot use `Tracked<T>` :( */
|
||||||
|
// XFAIL: *
|
||||||
|
|
||||||
import StdlibUnittest
|
import StdlibUnittest
|
||||||
import DifferentiationUnittest
|
import DifferentiationUnittest
|
||||||
|
|||||||
@@ -9,50 +9,50 @@ var MethodTests = TestSuite("Method")
|
|||||||
// ==== Tests with generated adjoint ====
|
// ==== Tests with generated adjoint ====
|
||||||
|
|
||||||
struct Parameter : Equatable {
|
struct Parameter : Equatable {
|
||||||
private let storedX: Tracked<Float>
|
private let storedX: Float
|
||||||
@differentiable(reverse, wrt: (self))
|
@differentiable(reverse, wrt: (self))
|
||||||
var x: Tracked<Float> {
|
var x: Float {
|
||||||
return storedX
|
return storedX
|
||||||
}
|
}
|
||||||
|
|
||||||
init(x: Tracked<Float>) {
|
init(x: Float) {
|
||||||
storedX = x
|
storedX = x
|
||||||
}
|
}
|
||||||
|
|
||||||
@derivative(of: x)
|
@derivative(of: x)
|
||||||
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Parameter) {
|
func vjpX() -> (value: Float, pullback: (Float) -> Parameter) {
|
||||||
return (x, { dx in Parameter(x: dx) } )
|
return (x, { dx in Parameter(x: dx) } )
|
||||||
}
|
}
|
||||||
|
|
||||||
@derivative(of: x)
|
@derivative(of: x)
|
||||||
func jvpX() -> (value: Tracked<Float>, differential: (Parameter) -> Tracked<Float>) {
|
func jvpX() -> (value: Float, differential: (Parameter) -> Float) {
|
||||||
return (x, { $0.x })
|
return (x, { $0.x })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extension Parameter {
|
extension Parameter {
|
||||||
func squared() -> Tracked<Float> {
|
func squared() -> Float {
|
||||||
return x * x
|
return x * x
|
||||||
}
|
}
|
||||||
|
|
||||||
static func squared(p: Parameter) -> Tracked<Float> {
|
static func squared(p: Parameter) -> Float {
|
||||||
return p.x * p.x
|
return p.x * p.x
|
||||||
}
|
}
|
||||||
|
|
||||||
func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
|
func multiplied(with other: Float) -> Float {
|
||||||
return x * other
|
return x * other
|
||||||
}
|
}
|
||||||
|
|
||||||
static func * (_ a: Parameter, _ b: Parameter) -> Tracked<Float> {
|
static func * (_ a: Parameter, _ b: Parameter) -> Float {
|
||||||
return a.x * b.x
|
return a.x * b.x
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extension Parameter : Differentiable, AdditiveArithmetic {
|
extension Parameter : Differentiable, AdditiveArithmetic {
|
||||||
typealias TangentVector = Parameter
|
typealias TangentVector = Parameter
|
||||||
typealias Scalar = Tracked<Float>
|
typealias Scalar = Float
|
||||||
typealias Shape = ()
|
typealias Shape = ()
|
||||||
init(repeating repeatedValue: Tracked<Float>, shape: ()) {
|
init(repeating repeatedValue: Float, shape: ()) {
|
||||||
self.init(x: repeatedValue)
|
self.init(x: repeatedValue)
|
||||||
}
|
}
|
||||||
static func + (lhs: Parameter, rhs: Parameter) -> Parameter {
|
static func + (lhs: Parameter, rhs: Parameter) -> Parameter {
|
||||||
@@ -67,37 +67,185 @@ extension Parameter : Differentiable, AdditiveArithmetic {
|
|||||||
static var zero: Parameter { return Parameter(x: 0) }
|
static var zero: Parameter { return Parameter(x: 0) }
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking(
|
MethodTests.test(
|
||||||
"instance method with generated adjoint, called from differentiated func"
|
"instance method with generated adjoint, called from differentiated func"
|
||||||
) {
|
) {
|
||||||
func f(_ p: Parameter) -> Tracked<Float> {
|
func f(_ p: Parameter) -> Float {
|
||||||
return 100 * p.squared()
|
return 100 * p.squared()
|
||||||
}
|
}
|
||||||
expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
|
expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
|
||||||
expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
|
expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MethodTests.test(
|
||||||
|
"instance method with generated adjoint, differentiated directly"
|
||||||
|
) {
|
||||||
|
// This is our current syntax for taking gradients of instance methods
|
||||||
|
// directly. If/when we develop nicer syntax for this, change this test.
|
||||||
|
func g(p: Parameter) -> Float { p.squared() }
|
||||||
|
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: g))
|
||||||
|
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: g))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("instance method with generated adjoint, wrt only self") {
|
||||||
|
func f(_ p: Parameter) -> Float {
|
||||||
|
return 100 * p.multiplied(with: 200)
|
||||||
|
}
|
||||||
|
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
||||||
|
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("instance method with generated adjoint, wrt only non-self") {
|
||||||
|
func f(_ other: Float) -> Float {
|
||||||
|
return 100 * Parameter(x: 200).multiplied(with: other)
|
||||||
|
}
|
||||||
|
expectEqual(100 * 200, gradient(at: 1, of: f))
|
||||||
|
expectEqual(100 * 200, gradient(at: 2, of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test(
|
||||||
|
"instance method with generated adjoint, wrt self and non-self"
|
||||||
|
) {
|
||||||
|
expectEqual(
|
||||||
|
(Parameter(x: 100), 200), gradient(at: Parameter(x: 200), 100) { $0.multiplied(with: $1) })
|
||||||
|
expectEqual(
|
||||||
|
(Parameter(x: 200), 100), gradient(at: Parameter(x: 100), 200) { $0.multiplied(with: $1) })
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test(
|
||||||
|
"static method with generated adjoint, called from differentiated func"
|
||||||
|
) {
|
||||||
|
func f(_ p: Parameter) -> Float {
|
||||||
|
return 100 * Parameter.squared(p: p)
|
||||||
|
}
|
||||||
|
expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
|
||||||
|
expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test(
|
||||||
|
"static method with generated adjoint, differentiated directly"
|
||||||
|
) {
|
||||||
|
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: Parameter.squared))
|
||||||
|
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: Parameter.squared))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("static method with generated adjoint, wrt only first param") {
|
||||||
|
func f(_ p: Parameter) -> Float {
|
||||||
|
return 100 * (p * Parameter(x: 200))
|
||||||
|
}
|
||||||
|
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
||||||
|
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("static method with generated adjoint, wrt only second param") {
|
||||||
|
func f(_ p: Parameter) -> Float {
|
||||||
|
return 100 * (Parameter(x: 200) * p)
|
||||||
|
}
|
||||||
|
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
||||||
|
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("static method with generated adjoint, wrt all params") {
|
||||||
|
func g(a: Parameter, b: Parameter) -> Float { a * b }
|
||||||
|
expectEqual((Parameter(x: 100), Parameter(x: 200)),
|
||||||
|
gradient(at: Parameter(x: 200), Parameter(x: 100), of: g))
|
||||||
|
expectEqual((Parameter(x: 200), Parameter(x: 100)),
|
||||||
|
gradient(at: Parameter(x: 100), Parameter(x: 200), of: g))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||||
|
We cannot use `Tracked<T>` :(
|
||||||
|
struct ParameterTracked : Equatable {
|
||||||
|
private let storedX: Tracked<Float>
|
||||||
|
@differentiable(reverse, wrt: (self))
|
||||||
|
var x: Tracked<Float> {
|
||||||
|
return storedX
|
||||||
|
}
|
||||||
|
|
||||||
|
init(x: Tracked<Float>) {
|
||||||
|
storedX = x
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: x)
|
||||||
|
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> ParameterTracked) {
|
||||||
|
return (x, { dx in ParameterTracked(x: dx) } )
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: x)
|
||||||
|
func jvpX() -> (value: Tracked<Float>, differential: (ParameterTracked) -> Tracked<Float>) {
|
||||||
|
return (x, { $0.x })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension ParameterTracked {
|
||||||
|
func squared() -> Tracked<Float> {
|
||||||
|
return x * x
|
||||||
|
}
|
||||||
|
|
||||||
|
static func squared(p: ParameterTracked) -> Tracked<Float> {
|
||||||
|
return p.x * p.x
|
||||||
|
}
|
||||||
|
|
||||||
|
func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
|
||||||
|
return x * other
|
||||||
|
}
|
||||||
|
|
||||||
|
static func * (_ a: ParameterTracked, _ b: ParameterTracked) -> Tracked<Float> {
|
||||||
|
return a.x * b.x
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension ParameterTracked : Differentiable, AdditiveArithmetic {
|
||||||
|
typealias TangentVector = ParameterTracked
|
||||||
|
typealias Scalar = Tracked<Float>
|
||||||
|
typealias Shape = ()
|
||||||
|
init(repeating repeatedValue: Tracked<Float>, shape: ()) {
|
||||||
|
self.init(x: repeatedValue)
|
||||||
|
}
|
||||||
|
static func + (lhs: ParameterTracked, rhs: ParameterTracked) -> ParameterTracked {
|
||||||
|
return ParameterTracked(x: lhs.x + rhs.x)
|
||||||
|
}
|
||||||
|
static func - (lhs: ParameterTracked, rhs: ParameterTracked) -> ParameterTracked {
|
||||||
|
return ParameterTracked(x: lhs.x - rhs.x)
|
||||||
|
}
|
||||||
|
static func * (lhs: Scalar, rhs: ParameterTracked) -> ParameterTracked {
|
||||||
|
return ParameterTracked(x: lhs * rhs.x)
|
||||||
|
}
|
||||||
|
static var zero: ParameterTracked { return ParameterTracked(x: 0) }
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.testWithLeakChecking(
|
||||||
|
"instance method with generated adjoint, called from differentiated func"
|
||||||
|
) {
|
||||||
|
func f(_ p: ParameterTracked) -> Tracked<Float> {
|
||||||
|
return 100 * p.squared()
|
||||||
|
}
|
||||||
|
expectEqual(ParameterTracked(x: 4 * 100), gradient(at: ParameterTracked(x: 2), of: f))
|
||||||
|
expectEqual(ParameterTracked(x: 40 * 100), gradient(at: ParameterTracked(x: 20), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking(
|
MethodTests.testWithLeakChecking(
|
||||||
"instance method with generated adjoint, differentiated directly"
|
"instance method with generated adjoint, differentiated directly"
|
||||||
) {
|
) {
|
||||||
// This is our current syntax for taking gradients of instance methods
|
// This is our current syntax for taking gradients of instance methods
|
||||||
// directly. If/when we develop nicer syntax for this, change this test.
|
// directly. If/when we develop nicer syntax for this, change this test.
|
||||||
func g(p: Parameter) -> Tracked<Float> { p.squared() }
|
func g(p: ParameterTracked) -> Tracked<Float> { p.squared() }
|
||||||
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: g))
|
expectEqual(ParameterTracked(x: 4), gradient(at: ParameterTracked(x: 2), of: g))
|
||||||
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: g))
|
expectEqual(ParameterTracked(x: 40), gradient(at: ParameterTracked(x: 20), of: g))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only self") {
|
MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only self") {
|
||||||
func f(_ p: Parameter) -> Tracked<Float> {
|
func f(_ p: ParameterTracked) -> Tracked<Float> {
|
||||||
return 100 * p.multiplied(with: 200)
|
return 100 * p.multiplied(with: 200)
|
||||||
}
|
}
|
||||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
|
||||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only non-self") {
|
MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only non-self") {
|
||||||
func f(_ other: Tracked<Float>) -> Tracked<Float> {
|
func f(_ other: Tracked<Float>) -> Tracked<Float> {
|
||||||
return 100 * Parameter(x: 200).multiplied(with: other)
|
return 100 * ParameterTracked(x: 200).multiplied(with: other)
|
||||||
}
|
}
|
||||||
expectEqual(100 * 200, gradient(at: 1, of: f))
|
expectEqual(100 * 200, gradient(at: 1, of: f))
|
||||||
expectEqual(100 * 200, gradient(at: 2, of: f))
|
expectEqual(100 * 200, gradient(at: 2, of: f))
|
||||||
@@ -107,51 +255,52 @@ MethodTests.testWithLeakChecking(
|
|||||||
"instance method with generated adjoint, wrt self and non-self"
|
"instance method with generated adjoint, wrt self and non-self"
|
||||||
) {
|
) {
|
||||||
expectEqual(
|
expectEqual(
|
||||||
(Parameter(x: 100), 200), gradient(at: Parameter(x: 200), 100) { $0.multiplied(with: $1) })
|
(ParameterTracked(x: 100), 200), gradient(at: ParameterTracked(x: 200), 100) { $0.multiplied(with: $1) })
|
||||||
expectEqual(
|
expectEqual(
|
||||||
(Parameter(x: 200), 100), gradient(at: Parameter(x: 100), 200) { $0.multiplied(with: $1) })
|
(ParameterTracked(x: 200), 100), gradient(at: ParameterTracked(x: 100), 200) { $0.multiplied(with: $1) })
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking(
|
MethodTests.testWithLeakChecking(
|
||||||
"static method with generated adjoint, called from differentiated func"
|
"static method with generated adjoint, called from differentiated func"
|
||||||
) {
|
) {
|
||||||
func f(_ p: Parameter) -> Tracked<Float> {
|
func f(_ p: ParameterTracked) -> Tracked<Float> {
|
||||||
return 100 * Parameter.squared(p: p)
|
return 100 * ParameterTracked.squared(p: p)
|
||||||
}
|
}
|
||||||
expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
|
expectEqual(ParameterTracked(x: 4 * 100), gradient(at: ParameterTracked(x: 2), of: f))
|
||||||
expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
|
expectEqual(ParameterTracked(x: 40 * 100), gradient(at: ParameterTracked(x: 20), of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking(
|
MethodTests.testWithLeakChecking(
|
||||||
"static method with generated adjoint, differentiated directly"
|
"static method with generated adjoint, differentiated directly"
|
||||||
) {
|
) {
|
||||||
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: Parameter.squared))
|
expectEqual(ParameterTracked(x: 4), gradient(at: ParameterTracked(x: 2), of: ParameterTracked.squared))
|
||||||
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: Parameter.squared))
|
expectEqual(ParameterTracked(x: 40), gradient(at: ParameterTracked(x: 20), of: ParameterTracked.squared))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only first param") {
|
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only first param") {
|
||||||
func f(_ p: Parameter) -> Tracked<Float> {
|
func f(_ p: ParameterTracked) -> Tracked<Float> {
|
||||||
return 100 * (p * Parameter(x: 200))
|
return 100 * (p * ParameterTracked(x: 200))
|
||||||
}
|
}
|
||||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
|
||||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only second param") {
|
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only second param") {
|
||||||
func f(_ p: Parameter) -> Tracked<Float> {
|
func f(_ p: ParameterTracked) -> Tracked<Float> {
|
||||||
return 100 * (Parameter(x: 200) * p)
|
return 100 * (ParameterTracked(x: 200) * p)
|
||||||
}
|
}
|
||||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
|
||||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt all params") {
|
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt all params") {
|
||||||
func g(a: Parameter, b: Parameter) -> Tracked<Float> { a * b }
|
func g(a: ParameterTracked, b: ParameterTracked) -> Tracked<Float> { a * b }
|
||||||
expectEqual((Parameter(x: 100), Parameter(x: 200)),
|
expectEqual((ParameterTracked(x: 100), ParameterTracked(x: 200)),
|
||||||
gradient(at: Parameter(x: 200), Parameter(x: 100), of: g))
|
gradient(at: ParameterTracked(x: 200), ParameterTracked(x: 100), of: g))
|
||||||
expectEqual((Parameter(x: 200), Parameter(x: 100)),
|
expectEqual((ParameterTracked(x: 200), ParameterTracked(x: 100)),
|
||||||
gradient(at: Parameter(x: 100), Parameter(x: 200), of: g))
|
gradient(at: ParameterTracked(x: 100), ParameterTracked(x: 200), of: g))
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
// ==== Tests with custom adjoint ====
|
// ==== Tests with custom adjoint ====
|
||||||
|
|
||||||
@@ -174,27 +323,27 @@ struct DiffWrtSelf : Differentiable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct CustomParameter : Equatable {
|
struct CustomParameter : Equatable {
|
||||||
let storedX: Tracked<Float>
|
let storedX: Float
|
||||||
@differentiable(reverse, wrt: (self))
|
@differentiable(reverse, wrt: (self))
|
||||||
var x: Tracked<Float> {
|
var x: Float {
|
||||||
return storedX
|
return storedX
|
||||||
}
|
}
|
||||||
|
|
||||||
init(x: Tracked<Float>) {
|
init(x: Float) {
|
||||||
storedX = x
|
storedX = x
|
||||||
}
|
}
|
||||||
|
|
||||||
@derivative(of: x)
|
@derivative(of: x)
|
||||||
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) {
|
func vjpX() -> (value: Float, pullback: (Float) -> CustomParameter) {
|
||||||
return (x, { dx in CustomParameter(x: dx) })
|
return (x, { dx in CustomParameter(x: dx) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extension CustomParameter : Differentiable, AdditiveArithmetic {
|
extension CustomParameter : Differentiable, AdditiveArithmetic {
|
||||||
typealias TangentVector = CustomParameter
|
typealias TangentVector = CustomParameter
|
||||||
typealias Scalar = Tracked<Float>
|
typealias Scalar = Float
|
||||||
typealias Shape = ()
|
typealias Shape = ()
|
||||||
init(repeating repeatedValue: Tracked<Float>, shape: ()) {
|
init(repeating repeatedValue: Float, shape: ()) {
|
||||||
self.init(x: repeatedValue)
|
self.init(x: repeatedValue)
|
||||||
}
|
}
|
||||||
static func + (lhs: CustomParameter, rhs: CustomParameter) -> CustomParameter {
|
static func + (lhs: CustomParameter, rhs: CustomParameter) -> CustomParameter {
|
||||||
@@ -209,38 +358,256 @@ extension CustomParameter : Differentiable, AdditiveArithmetic {
|
|||||||
static var zero: CustomParameter { return CustomParameter(x: 0) }
|
static var zero: CustomParameter { return CustomParameter(x: 0) }
|
||||||
}
|
}
|
||||||
|
|
||||||
extension Tracked where T : FloatingPoint {
|
extension Float {
|
||||||
func clamped(to limits: ClosedRange<Tracked<T>>) -> Tracked<T> {
|
func clamped(to limits: ClosedRange<Float>) -> Float {
|
||||||
return min(max(self, limits.lowerBound), limits.upperBound)
|
return min(max(self, limits.lowerBound), limits.upperBound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extension CustomParameter {
|
extension CustomParameter {
|
||||||
@differentiable(reverse, wrt: (self))
|
@differentiable(reverse, wrt: (self))
|
||||||
func squared() -> Tracked<Float> {
|
func squared() -> Float {
|
||||||
return x * x
|
return x * x
|
||||||
}
|
}
|
||||||
|
|
||||||
@derivative(of: squared)
|
@derivative(of: squared)
|
||||||
func dSquared() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) {
|
func dSquared() -> (value: Float, pullback: (Float) -> CustomParameter) {
|
||||||
return (squared(), { [x] v in CustomParameter(x: (2 * x).clamped(to: -10.0...10.0) * v) })
|
return (squared(), { [x] v in CustomParameter(x: (2 * x).clamped(to: -10.0...10.0) * v) })
|
||||||
}
|
}
|
||||||
|
|
||||||
@differentiable(reverse)
|
@differentiable(reverse)
|
||||||
static func squared(p: CustomParameter) -> Tracked<Float> {
|
static func squared(p: CustomParameter) -> Float {
|
||||||
return p.x * p.x
|
return p.x * p.x
|
||||||
}
|
}
|
||||||
|
|
||||||
@derivative(of: squared)
|
@derivative(of: squared)
|
||||||
static func dSquared(
|
static func dSquared(
|
||||||
_ p: CustomParameter
|
_ p: CustomParameter
|
||||||
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) {
|
) -> (value: Float, pullback: (Float) -> CustomParameter) {
|
||||||
return (p.x * p.x, { v in CustomParameter(x: (2 * p.x).clamped(to: -10.0...10.0) * v) })
|
return (p.x * p.x, { v in CustomParameter(x: (2 * p.x).clamped(to: -10.0...10.0) * v) })
|
||||||
}
|
}
|
||||||
|
|
||||||
// There is currently no way to define multiple custom VJPs wrt different
|
// There is currently no way to define multiple custom VJPs wrt different
|
||||||
// parameters on the same func, so we define a copy of this func per adjoint.
|
// parameters on the same func, so we define a copy of this func per adjoint.
|
||||||
|
|
||||||
|
@differentiable(reverse, wrt: (self, other))
|
||||||
|
func multiplied(with other: Float) -> Float {
|
||||||
|
return x * other
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse, wrt: (other))
|
||||||
|
func multiplied_constSelf(with other: Float) -> Float {
|
||||||
|
return x * other
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse, wrt: (self))
|
||||||
|
func multiplied_constOther(with other: Float) -> Float {
|
||||||
|
return x * other
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: multiplied)
|
||||||
|
func dMultiplied_wrtAll(
|
||||||
|
with other: Float
|
||||||
|
) -> (value: Float, pullback: (Float) -> (CustomParameter, Float)) {
|
||||||
|
return (multiplied(with: other),
|
||||||
|
{ [x] v in (CustomParameter(x: other.clamped(to: -10.0...10.0) * v),
|
||||||
|
x.clamped(to: -10.0...10.0) * v) })
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: multiplied_constSelf, wrt: other)
|
||||||
|
func dMultiplied_wrtOther(
|
||||||
|
with other: Float
|
||||||
|
) -> (value: Float, pullback: (Float) -> Float) {
|
||||||
|
let (r, pb) = dMultiplied_wrtAll(with: other)
|
||||||
|
return (r, { v in pb(v).1 })
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: multiplied_constOther, wrt: self)
|
||||||
|
func dMultiplied_wrtSelf(
|
||||||
|
with other: Float
|
||||||
|
) -> (value: Float, pullback: (Float) -> CustomParameter) {
|
||||||
|
let (r, pb) = dMultiplied_wrtAll(with: other)
|
||||||
|
return (r, { v in pb(v).0 })
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
static func multiply(_ lhs: CustomParameter, _ rhs: CustomParameter)
|
||||||
|
-> Float {
|
||||||
|
return lhs.x * rhs.x
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse, wrt: (rhs))
|
||||||
|
static func multiply_constLhs(_ lhs: CustomParameter, _ rhs: CustomParameter) -> Float {
|
||||||
|
return lhs.x * rhs.x
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: multiply)
|
||||||
|
static func dMultiply_wrtAll(_ lhs: CustomParameter,_ rhs: CustomParameter)
|
||||||
|
-> (value: Float, pullback: (Float) -> (CustomParameter, CustomParameter)) {
|
||||||
|
let result = multiply(lhs, rhs)
|
||||||
|
return (result, { v in (CustomParameter(x: rhs.x.clamped(to: -10.0...10.0) * v),
|
||||||
|
CustomParameter(x: lhs.x.clamped(to: -10.0...10.0) * v)) })
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: multiply_constLhs, wrt: rhs)
|
||||||
|
static func dMultiply_wrtRhs(_ lhs: CustomParameter, _ rhs: CustomParameter)
|
||||||
|
-> (value: Float, pullback: (Float) -> CustomParameter) {
|
||||||
|
let (r, pb) = dMultiply_wrtAll(lhs, rhs)
|
||||||
|
return (r, { v in pb(v).1 })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test(
|
||||||
|
"instance method with custom adjoint, called from differentiated func"
|
||||||
|
) {
|
||||||
|
func f(_ p: CustomParameter) -> Float {
|
||||||
|
return 100 * p.squared()
|
||||||
|
}
|
||||||
|
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
|
||||||
|
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("instance method with generated adjoint, differentiated directly") {
|
||||||
|
// This is our current syntax for taking gradients of instance methods
|
||||||
|
// directly. If/when we develop nicer syntax for this, change this test.
|
||||||
|
func g(p: CustomParameter) -> Float { p.squared() }
|
||||||
|
expectEqual(CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: g))
|
||||||
|
expectEqual(CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: g))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("static method with custom adjoint, called from differentiated func") {
|
||||||
|
func f(_ p: CustomParameter) -> Float {
|
||||||
|
return 100 * CustomParameter.squared(p: p)
|
||||||
|
}
|
||||||
|
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
|
||||||
|
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("static method with custom adjoint, differentiated directly") {
|
||||||
|
expectEqual(
|
||||||
|
CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: CustomParameter.squared))
|
||||||
|
expectEqual(
|
||||||
|
CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: CustomParameter.squared))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("instance method with custom adjoint, wrt only self") {
|
||||||
|
func f(_ p: CustomParameter) -> Float {
|
||||||
|
return 100 * p.multiplied_constOther(with: 200)
|
||||||
|
}
|
||||||
|
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
||||||
|
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("instance method with custom adjoint, wrt only non-self") {
|
||||||
|
func f(_ other: Float) -> Float {
|
||||||
|
return 100 * CustomParameter(x: 200).multiplied_constSelf(with: other)
|
||||||
|
}
|
||||||
|
expectEqual(100 * 10, gradient(at: 1, of: f))
|
||||||
|
expectEqual(100 * 10, gradient(at: 2, of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("instance method with custom adjoint, wrt self and non-self") {
|
||||||
|
func g(p: CustomParameter, o: Float) -> Float { p.multiplied(with: o) }
|
||||||
|
expectEqual((CustomParameter(x: 5), 10), gradient(at: CustomParameter(x: 100), 5, of: g))
|
||||||
|
expectEqual((CustomParameter(x: 10), 5), gradient(at: CustomParameter(x: 5), 100, of: g))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("static method with custom adjoint, wrt only lhs") {
|
||||||
|
func f(_ p: CustomParameter) -> Float {
|
||||||
|
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
|
||||||
|
}
|
||||||
|
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
||||||
|
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("static method with custom adjoint, wrt only rhs") {
|
||||||
|
func f(_ p: CustomParameter) -> Float {
|
||||||
|
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
|
||||||
|
}
|
||||||
|
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
||||||
|
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTests.test("static method with custom adjoint, wrt all") {
|
||||||
|
func f(_ a: CustomParameter, _ b: CustomParameter) -> Float {
|
||||||
|
return CustomParameter.multiply(a, b)
|
||||||
|
}
|
||||||
|
expectEqual((CustomParameter(x: 5), CustomParameter(x: 10)),
|
||||||
|
gradient(at: CustomParameter(x: 100), CustomParameter(x: 5), of: f))
|
||||||
|
expectEqual((CustomParameter(x: 10), CustomParameter(x: 5)),
|
||||||
|
gradient(at: CustomParameter(x: 5), CustomParameter(x: 100), of: f))
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||||
|
We cannot use `Tracked<T>` :(
|
||||||
|
struct CustomParameterTracked : Equatable {
|
||||||
|
let storedX: Tracked<Float>
|
||||||
|
@differentiable(reverse, wrt: (self))
|
||||||
|
var x: Tracked<Float> {
|
||||||
|
return storedX
|
||||||
|
}
|
||||||
|
|
||||||
|
init(x: Tracked<Float>) {
|
||||||
|
storedX = x
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: x)
|
||||||
|
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
|
||||||
|
return (x, { dx in CustomParameterTracked(x: dx) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension CustomParameterTracked : Differentiable, AdditiveArithmetic {
|
||||||
|
typealias TangentVector = CustomParameterTracked
|
||||||
|
typealias Scalar = Tracked<Float>
|
||||||
|
typealias Shape = ()
|
||||||
|
init(repeating repeatedValue: Tracked<Float>, shape: ()) {
|
||||||
|
self.init(x: repeatedValue)
|
||||||
|
}
|
||||||
|
static func + (lhs: CustomParameterTracked, rhs: CustomParameterTracked) -> CustomParameterTracked {
|
||||||
|
return CustomParameterTracked(x: lhs.x + rhs.x)
|
||||||
|
}
|
||||||
|
static func - (lhs: CustomParameterTracked, rhs: CustomParameterTracked) -> CustomParameterTracked {
|
||||||
|
return CustomParameterTracked(x: lhs.x - rhs.x)
|
||||||
|
}
|
||||||
|
static func * (lhs: Scalar, rhs: CustomParameterTracked) -> CustomParameterTracked {
|
||||||
|
return CustomParameterTracked(x: lhs * rhs.x)
|
||||||
|
}
|
||||||
|
static var zero: CustomParameterTracked { return CustomParameterTracked(x: 0) }
|
||||||
|
}
|
||||||
|
|
||||||
|
extension Tracked where T : FloatingPoint {
|
||||||
|
func clamped(to limits: ClosedRange<Tracked<T>>) -> Tracked<T> {
|
||||||
|
return min(max(self, limits.lowerBound), limits.upperBound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension CustomParameterTracked {
|
||||||
|
@differentiable(reverse, wrt: (self))
|
||||||
|
func squared() -> Tracked<Float> {
|
||||||
|
return x * x
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: squared)
|
||||||
|
func dSquared() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
|
||||||
|
return (squared(), { [x] v in CustomParameterTracked(x: (2 * x).clamped(to: -10.0...10.0) * v) })
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
static func squared(p: CustomParameterTracked) -> Tracked<Float> {
|
||||||
|
return p.x * p.x
|
||||||
|
}
|
||||||
|
|
||||||
|
@derivative(of: squared)
|
||||||
|
static func dSquared(
|
||||||
|
_ p: CustomParameterTracked
|
||||||
|
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
|
||||||
|
return (p.x * p.x, { v in CustomParameterTracked(x: (2 * p.x).clamped(to: -10.0...10.0) * v) })
|
||||||
|
}
|
||||||
|
|
||||||
|
// There is currently no way to define multiple custom VJPs wrt different
|
||||||
|
// parameters on the same func, so we define a copy of this func per adjoint.
|
||||||
|
|
||||||
@differentiable(reverse, wrt: (self, other))
|
@differentiable(reverse, wrt: (self, other))
|
||||||
func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
|
func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
|
||||||
return x * other
|
return x * other
|
||||||
@@ -259,9 +626,9 @@ extension CustomParameter {
|
|||||||
@derivative(of: multiplied)
|
@derivative(of: multiplied)
|
||||||
func dMultiplied_wrtAll(
|
func dMultiplied_wrtAll(
|
||||||
with other: Tracked<Float>
|
with other: Tracked<Float>
|
||||||
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameter, Tracked<Float>)) {
|
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameterTracked, Tracked<Float>)) {
|
||||||
return (multiplied(with: other),
|
return (multiplied(with: other),
|
||||||
{ [x] v in (CustomParameter(x: other.clamped(to: -10.0...10.0) * v),
|
{ [x] v in (CustomParameterTracked(x: other.clamped(to: -10.0...10.0) * v),
|
||||||
x.clamped(to: -10.0...10.0) * v) })
|
x.clamped(to: -10.0...10.0) * v) })
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -276,33 +643,33 @@ extension CustomParameter {
|
|||||||
@derivative(of: multiplied_constOther, wrt: self)
|
@derivative(of: multiplied_constOther, wrt: self)
|
||||||
func dMultiplied_wrtSelf(
|
func dMultiplied_wrtSelf(
|
||||||
with other: Tracked<Float>
|
with other: Tracked<Float>
|
||||||
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) {
|
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
|
||||||
let (r, pb) = dMultiplied_wrtAll(with: other)
|
let (r, pb) = dMultiplied_wrtAll(with: other)
|
||||||
return (r, { v in pb(v).0 })
|
return (r, { v in pb(v).0 })
|
||||||
}
|
}
|
||||||
|
|
||||||
@differentiable(reverse)
|
@differentiable(reverse)
|
||||||
static func multiply(_ lhs: CustomParameter, _ rhs: CustomParameter)
|
static func multiply(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked)
|
||||||
-> Tracked<Float> {
|
-> Tracked<Float> {
|
||||||
return lhs.x * rhs.x
|
return lhs.x * rhs.x
|
||||||
}
|
}
|
||||||
|
|
||||||
@differentiable(reverse, wrt: (rhs))
|
@differentiable(reverse, wrt: (rhs))
|
||||||
static func multiply_constLhs(_ lhs: CustomParameter, _ rhs: CustomParameter) -> Tracked<Float> {
|
static func multiply_constLhs(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked) -> Tracked<Float> {
|
||||||
return lhs.x * rhs.x
|
return lhs.x * rhs.x
|
||||||
}
|
}
|
||||||
|
|
||||||
@derivative(of: multiply)
|
@derivative(of: multiply)
|
||||||
static func dMultiply_wrtAll(_ lhs: CustomParameter,_ rhs: CustomParameter)
|
static func dMultiply_wrtAll(_ lhs: CustomParameterTracked,_ rhs: CustomParameterTracked)
|
||||||
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameter, CustomParameter)) {
|
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameterTracked, CustomParameterTracked)) {
|
||||||
let result = multiply(lhs, rhs)
|
let result = multiply(lhs, rhs)
|
||||||
return (result, { v in (CustomParameter(x: rhs.x.clamped(to: -10.0...10.0) * v),
|
return (result, { v in (CustomParameterTracked(x: rhs.x.clamped(to: -10.0...10.0) * v),
|
||||||
CustomParameter(x: lhs.x.clamped(to: -10.0...10.0) * v)) })
|
CustomParameterTracked(x: lhs.x.clamped(to: -10.0...10.0) * v)) })
|
||||||
}
|
}
|
||||||
|
|
||||||
@derivative(of: multiply_constLhs, wrt: rhs)
|
@derivative(of: multiply_constLhs, wrt: rhs)
|
||||||
static func dMultiply_wrtRhs(_ lhs: CustomParameter, _ rhs: CustomParameter)
|
static func dMultiply_wrtRhs(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked)
|
||||||
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) {
|
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
|
||||||
let (r, pb) = dMultiply_wrtAll(lhs, rhs)
|
let (r, pb) = dMultiply_wrtAll(lhs, rhs)
|
||||||
return (r, { v in pb(v).1 })
|
return (r, { v in pb(v).1 })
|
||||||
}
|
}
|
||||||
@@ -311,82 +678,83 @@ extension CustomParameter {
|
|||||||
MethodTests.testWithLeakChecking(
|
MethodTests.testWithLeakChecking(
|
||||||
"instance method with custom adjoint, called from differentiated func"
|
"instance method with custom adjoint, called from differentiated func"
|
||||||
) {
|
) {
|
||||||
func f(_ p: CustomParameter) -> Tracked<Float> {
|
func f(_ p: CustomParameterTracked) -> Tracked<Float> {
|
||||||
return 100 * p.squared()
|
return 100 * p.squared()
|
||||||
}
|
}
|
||||||
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
|
expectEqual(CustomParameterTracked(x: 4 * 100), gradient(at: CustomParameterTracked(x: 2), of: f))
|
||||||
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
|
expectEqual(CustomParameterTracked(x: 10 * 100), gradient(at: CustomParameterTracked(x: 20), of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("instance method with generated adjoint, differentiated directly") {
|
MethodTests.testWithLeakChecking("instance method with generated adjoint, differentiated directly") {
|
||||||
// This is our current syntax for taking gradients of instance methods
|
// This is our current syntax for taking gradients of instance methods
|
||||||
// directly. If/when we develop nicer syntax for this, change this test.
|
// directly. If/when we develop nicer syntax for this, change this test.
|
||||||
func g(p: CustomParameter) -> Tracked<Float> { p.squared() }
|
func g(p: CustomParameterTracked) -> Tracked<Float> { p.squared() }
|
||||||
expectEqual(CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: g))
|
expectEqual(CustomParameterTracked(x: 4), gradient(at: CustomParameterTracked(x: 2), of: g))
|
||||||
expectEqual(CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: g))
|
expectEqual(CustomParameterTracked(x: 10), gradient(at: CustomParameterTracked(x: 20), of: g))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("static method with custom adjoint, called from differentiated func") {
|
MethodTests.testWithLeakChecking("static method with custom adjoint, called from differentiated func") {
|
||||||
func f(_ p: CustomParameter) -> Tracked<Float> {
|
func f(_ p: CustomParameterTracked) -> Tracked<Float> {
|
||||||
return 100 * CustomParameter.squared(p: p)
|
return 100 * CustomParameterTracked.squared(p: p)
|
||||||
}
|
}
|
||||||
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
|
expectEqual(CustomParameterTracked(x: 4 * 100), gradient(at: CustomParameterTracked(x: 2), of: f))
|
||||||
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
|
expectEqual(CustomParameterTracked(x: 10 * 100), gradient(at: CustomParameterTracked(x: 20), of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("static method with custom adjoint, differentiated directly") {
|
MethodTests.testWithLeakChecking("static method with custom adjoint, differentiated directly") {
|
||||||
expectEqual(
|
expectEqual(
|
||||||
CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: CustomParameter.squared))
|
CustomParameterTracked(x: 4), gradient(at: CustomParameterTracked(x: 2), of: CustomParameterTracked.squared))
|
||||||
expectEqual(
|
expectEqual(
|
||||||
CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: CustomParameter.squared))
|
CustomParameterTracked(x: 10), gradient(at: CustomParameterTracked(x: 20), of: CustomParameterTracked.squared))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only self") {
|
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only self") {
|
||||||
func f(_ p: CustomParameter) -> Tracked<Float> {
|
func f(_ p: CustomParameterTracked) -> Tracked<Float> {
|
||||||
return 100 * p.multiplied_constOther(with: 200)
|
return 100 * p.multiplied_constOther(with: 200)
|
||||||
}
|
}
|
||||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
|
||||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only non-self") {
|
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only non-self") {
|
||||||
func f(_ other: Tracked<Float>) -> Tracked<Float> {
|
func f(_ other: Tracked<Float>) -> Tracked<Float> {
|
||||||
return 100 * CustomParameter(x: 200).multiplied_constSelf(with: other)
|
return 100 * CustomParameterTracked(x: 200).multiplied_constSelf(with: other)
|
||||||
}
|
}
|
||||||
expectEqual(100 * 10, gradient(at: 1, of: f))
|
expectEqual(100 * 10, gradient(at: 1, of: f))
|
||||||
expectEqual(100 * 10, gradient(at: 2, of: f))
|
expectEqual(100 * 10, gradient(at: 2, of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt self and non-self") {
|
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt self and non-self") {
|
||||||
func g(p: CustomParameter, o: Tracked<Float>) -> Tracked<Float> { p.multiplied(with: o) }
|
func g(p: CustomParameterTracked, o: Tracked<Float>) -> Tracked<Float> { p.multiplied(with: o) }
|
||||||
expectEqual((CustomParameter(x: 5), 10), gradient(at: CustomParameter(x: 100), 5, of: g))
|
expectEqual((CustomParameterTracked(x: 5), 10), gradient(at: CustomParameterTracked(x: 100), 5, of: g))
|
||||||
expectEqual((CustomParameter(x: 10), 5), gradient(at: CustomParameter(x: 5), 100, of: g))
|
expectEqual((CustomParameterTracked(x: 10), 5), gradient(at: CustomParameterTracked(x: 5), 100, of: g))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only lhs") {
|
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only lhs") {
|
||||||
func f(_ p: CustomParameter) -> Tracked<Float> {
|
func f(_ p: CustomParameterTracked) -> Tracked<Float> {
|
||||||
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
|
return 100 * CustomParameterTracked.multiply_constLhs(CustomParameterTracked(x: 200), p)
|
||||||
}
|
}
|
||||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
|
||||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only rhs") {
|
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only rhs") {
|
||||||
func f(_ p: CustomParameter) -> Tracked<Float> {
|
func f(_ p: CustomParameterTracked) -> Tracked<Float> {
|
||||||
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
|
return 100 * CustomParameterTracked.multiply_constLhs(CustomParameterTracked(x: 200), p)
|
||||||
}
|
}
|
||||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
|
||||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt all") {
|
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt all") {
|
||||||
func f(_ a: CustomParameter, _ b: CustomParameter) -> Tracked<Float> {
|
func f(_ a: CustomParameterTracked, _ b: CustomParameterTracked) -> Tracked<Float> {
|
||||||
return CustomParameter.multiply(a, b)
|
return CustomParameterTracked.multiply(a, b)
|
||||||
}
|
}
|
||||||
expectEqual((CustomParameter(x: 5), CustomParameter(x: 10)),
|
expectEqual((CustomParameterTracked(x: 5), CustomParameterTracked(x: 10)),
|
||||||
gradient(at: CustomParameter(x: 100), CustomParameter(x: 5), of: f))
|
gradient(at: CustomParameterTracked(x: 100), CustomParameterTracked(x: 5), of: f))
|
||||||
expectEqual((CustomParameter(x: 10), CustomParameter(x: 5)),
|
expectEqual((CustomParameterTracked(x: 10), CustomParameterTracked(x: 5)),
|
||||||
gradient(at: CustomParameter(x: 5), CustomParameter(x: 100), of: f))
|
gradient(at: CustomParameterTracked(x: 5), CustomParameterTracked(x: 100), of: f))
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
runAllTests()
|
runAllTests()
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// RUN: %target-run-simple-swift
|
1// RUN: %target-run-simple-swift
|
||||||
// REQUIRES: executable_test
|
// REQUIRES: executable_test
|
||||||
|
|
||||||
import StdlibUnittest
|
import StdlibUnittest
|
||||||
@@ -6,7 +6,19 @@ import DifferentiationUnittest
|
|||||||
|
|
||||||
var RepeatedCallsTests = TestSuite("RepeatedCalls")
|
var RepeatedCallsTests = TestSuite("RepeatedCalls")
|
||||||
|
|
||||||
RepeatedCallsTests.testWithLeakChecking("Repeat") {
|
RepeatedCallsTests.test("Repeat") {
|
||||||
|
func mul2(_ x: Float) -> Float {
|
||||||
|
return 2 * x
|
||||||
|
}
|
||||||
|
func mul4(_ x: Float) -> Float {
|
||||||
|
return mul2(mul2(x))
|
||||||
|
}
|
||||||
|
expectEqual(4, gradient(at: 0, of: mul4))
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||||
|
We cannot use `Tracked<T>` :(
|
||||||
|
RepeatedCallsTests.testWithLeakChecking("Repeat-Tracked") {
|
||||||
func mul2(_ x: Tracked<Float>) -> Tracked<Float> {
|
func mul2(_ x: Tracked<Float>) -> Tracked<Float> {
|
||||||
return 2 * x
|
return 2 * x
|
||||||
}
|
}
|
||||||
@@ -15,5 +27,6 @@ RepeatedCallsTests.testWithLeakChecking("Repeat") {
|
|||||||
}
|
}
|
||||||
expectEqual(4, gradient(at: 0, of: mul4))
|
expectEqual(4, gradient(at: 0, of: mul4))
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
runAllTests()
|
runAllTests()
|
||||||
|
|||||||
@@ -266,6 +266,9 @@ SimpleMathTests.test("TupleMutation") {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Tests TF-321.
|
// Tests TF-321.
|
||||||
|
|
||||||
|
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||||
|
We cannot use `Tracked<T>` :(
|
||||||
SimpleMathTests.test("TupleNonDifferentiableElements") {
|
SimpleMathTests.test("TupleNonDifferentiableElements") {
|
||||||
// TF-964: Test tuple with non-tuple-typed adjoint value.
|
// TF-964: Test tuple with non-tuple-typed adjoint value.
|
||||||
func tupleLet(_ x: Tracked<Float>) -> Tracked<Float> {
|
func tupleLet(_ x: Tracked<Float>) -> Tracked<Float> {
|
||||||
@@ -309,6 +312,51 @@ SimpleMathTests.test("TupleNonDifferentiableElements") {
|
|||||||
}
|
}
|
||||||
expectEqual((3, 1), valueWithGradient(at: 3, of: wrapper))
|
expectEqual((3, 1), valueWithGradient(at: 3, of: wrapper))
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
SimpleMathTests.test("TupleNonDifferentiableElementsNotTracked") {
|
||||||
|
// TF-964: Test tuple with non-tuple-typed adjoint value.
|
||||||
|
func tupleLet(_ x: Float) -> Float {
|
||||||
|
let tuple = (2 * x, 1)
|
||||||
|
return tuple.0
|
||||||
|
}
|
||||||
|
expectEqual((8, 2), valueWithGradient(at: 4, of: tupleLet))
|
||||||
|
|
||||||
|
func tupleVar(_ x: Float) -> Float {
|
||||||
|
var tuple = (x, 1)
|
||||||
|
tuple.0 = x
|
||||||
|
tuple.1 = 1
|
||||||
|
return tuple.0
|
||||||
|
}
|
||||||
|
expectEqual((3, 1), valueWithGradient(at: 3, of: tupleVar))
|
||||||
|
|
||||||
|
func nested(_ x: Float) -> Float {
|
||||||
|
// Convoluted function computing `x * x`.
|
||||||
|
var tuple: (Int, (Int, Float), Float) = (1, (1, 0), 0)
|
||||||
|
tuple.0 = 1
|
||||||
|
tuple.1.0 = 1
|
||||||
|
tuple.1.1 = x
|
||||||
|
tuple.2 = x
|
||||||
|
return tuple.1.1 * tuple.2
|
||||||
|
}
|
||||||
|
expectEqual((16, 8), valueWithGradient(at: 4, of: nested))
|
||||||
|
|
||||||
|
struct Wrapper<T> {
|
||||||
|
@differentiable(reverse where T : Differentiable)
|
||||||
|
func baz(_ x: T) -> T {
|
||||||
|
var tuple = (1, 1, x, 1)
|
||||||
|
tuple.0 = 1
|
||||||
|
tuple.2 = x
|
||||||
|
tuple.3 = 1
|
||||||
|
return tuple.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func wrapper(_ x: Float) -> Float {
|
||||||
|
let w = Wrapper<Float>()
|
||||||
|
return w.baz(x)
|
||||||
|
}
|
||||||
|
expectEqual((3, 1), valueWithGradient(at: 3, of: wrapper))
|
||||||
|
}
|
||||||
|
|
||||||
// Tests TF-21.
|
// Tests TF-21.
|
||||||
SimpleMathTests.test("StructMemberwiseInitializer") {
|
SimpleMathTests.test("StructMemberwiseInitializer") {
|
||||||
|
|||||||
@@ -21,9 +21,10 @@ func test(arr: [any P]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make sure that we don't pick a concrete `(AnyHashable, AnyHashable) -> Bool` overload.
|
||||||
|
|
||||||
// CHECK: sil private [ossa] @$s34anyhashable_and_operator_filtering4test3arrySayAA1P_pG_tFyAaD_pXEfU_
|
// CHECK: sil private [ossa] @$s34anyhashable_and_operator_filtering4test3arrySayAA1P_pG_tFyAaD_pXEfU_
|
||||||
// CHECK: [[LHS_ARG:%.*]] = alloc_stack $E
|
// CHECK: [[LHS_ARG:%.*]] = alloc_stack $E
|
||||||
// CHECK: [[RHS_ARG:%.*]] = alloc_stack $E
|
// CHECK: [[RHS_ARG:%.*]] = alloc_stack $E
|
||||||
// CHECK: function_ref == infix<A>(_:_:)
|
// CHECK: [[GENERIC_OP:%.*]] = witness_method $E, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||||
// CHECK-NEXT: [[GENERIC_OP:%.*]] = function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF : $@convention(thin) <τ_0_0 where τ_0_0 : RawRepresentable, τ_0_0.RawValue : Equatable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> Bool
|
// CHECK-NEXT: apply [[GENERIC_OP]]<E>([[LHS_ARG]], [[RHS_ARG]], {{.*}})
|
||||||
// CHECK-NEXT: apply [[GENERIC_OP]]<E>([[LHS_ARG]], [[RHS_ARG]])
|
|
||||||
|
|||||||
@@ -8,9 +8,8 @@ struct Value: Equatable, ExpressibleByNilLiteral {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: sil hidden [ossa] @$s13rdar1580631514test1vyAA5ValueV_tF : $@convention(thin) (Value) -> ()
|
// CHECK-LABEL: sil hidden [ossa] @$s13rdar1580631514test1vyAA5ValueV_tF : $@convention(thin) (Value) -> ()
|
||||||
// function_ref static Value.__derived_struct_equals(_:_:)
|
// CHECK: [[EQUALS_REF:%.*]] = witness_method $Value, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||||
// CHECK: [[EQUALS_REF:%.*]] = function_ref @$s13rdar1580631515ValueV23__derived_struct_equalsySbAC_ACtFZ
|
// CHECK-NEXT: apply [[EQUALS_REF]]<Value>({{.*}})
|
||||||
// CHECK-NEXT: apply [[EQUALS_REF]](%0, {{.*}})
|
|
||||||
func test(v: Value) {
|
func test(v: Value) {
|
||||||
_ = v == nil
|
_ = v == nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,6 @@ func testFoo() {
|
|||||||
// CHECK: [[@LINE+1]]:7 | instance-method/Swift | hash(into:) | s:14swift_ide_test9CustomFooV4hash4intoys6HasherVz_tF | {{.*}}Ref
|
// CHECK: [[@LINE+1]]:7 | instance-method/Swift | hash(into:) | s:14swift_ide_test9CustomFooV4hash4intoys6HasherVz_tF | {{.*}}Ref
|
||||||
f.hash(into: &hasher)
|
f.hash(into: &hasher)
|
||||||
hasher.finalize()
|
hasher.finalize()
|
||||||
// CHECK: [[@LINE+1]]:11 | static-method/Swift | __derived_struct_equals(_:_:) | s:14swift_ide_test9CustomFooV23__derived_struct_equalsySbAC_ACtFZ | {{.*}}Ref
|
// CHECK: [[@LINE+1]]:11 | static-method/infix-operator/Swift | ==(_:_:) | s:SQ2eeoiySbx_xtFZ | {{.*}}Ref
|
||||||
_ = f == CustomFoo(a: 0, b: "b")
|
_ = f == CustomFoo(a: 0, b: "b")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,15 +48,15 @@ public enum Alphabet : String {
|
|||||||
|
|
||||||
// CHECK-LABEL: sil [ossa] @$s4main14check_alphabetySiAA8AlphabetOF : $@convention(thin) (Alphabet) -> Int {
|
// CHECK-LABEL: sil [ossa] @$s4main14check_alphabetySiAA8AlphabetOF : $@convention(thin) (Alphabet) -> Int {
|
||||||
public func check_alphabet(_ state : Alphabet) -> Int {
|
public func check_alphabet(_ state : Alphabet) -> Int {
|
||||||
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
// FRAGILE: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||||
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
// RESILIENT: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||||
return state == .E ? 1 : 0
|
return state == .E ? 1 : 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: sil [ossa] @$s4main9compareItySbAA8AlphabetO_ADtF : $@convention(thin) (Alphabet, Alphabet) -> Bool {
|
// CHECK-LABEL: sil [ossa] @$s4main9compareItySbAA8AlphabetO_ADtF : $@convention(thin) (Alphabet, Alphabet) -> Bool {
|
||||||
public func compareIt(_ state : Alphabet, _ rhs: Alphabet) -> Bool {
|
public func compareIt(_ state : Alphabet, _ rhs: Alphabet) -> Bool {
|
||||||
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
// FRAGILE: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||||
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
// RESILIENT: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||||
return state == rhs
|
return state == rhs
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,14 +67,14 @@ public enum AlphabetInt : Int {
|
|||||||
|
|
||||||
// CHECK-LABEL: sil [ossa] @$s4main18check_alphabet_intySiAA11AlphabetIntOF : $@convention(thin) (AlphabetInt) -> Int {
|
// CHECK-LABEL: sil [ossa] @$s4main18check_alphabet_intySiAA11AlphabetIntOF : $@convention(thin) (AlphabetInt) -> Int {
|
||||||
public func check_alphabet_int(_ state : AlphabetInt) -> Int {
|
public func check_alphabet_int(_ state : AlphabetInt) -> Int {
|
||||||
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
// FRAGILE: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||||
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
// RESILIENT: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||||
return state == .E ? 1 : 0
|
return state == .E ? 1 : 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: sil [ossa] @$s4main9compareItySbAA11AlphabetIntO_ADtF : $@convention(thin) (AlphabetInt, AlphabetInt) -> Bool {
|
// CHECK-LABEL: sil [ossa] @$s4main9compareItySbAA11AlphabetIntO_ADtF : $@convention(thin) (AlphabetInt, AlphabetInt) -> Bool {
|
||||||
public func compareIt(_ state : AlphabetInt, _ rhs: AlphabetInt) -> Bool {
|
public func compareIt(_ state : AlphabetInt, _ rhs: AlphabetInt) -> Bool {
|
||||||
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
// FRAGILE: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||||
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
// RESILIENT: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||||
return state == rhs
|
return state == rhs
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
|
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=SILGEN
|
||||||
|
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s --check-prefix=OPTIMIZED
|
||||||
|
|
||||||
|
// Operators are no longer devirtualized at AST level, it's done during SIL optimization.
|
||||||
|
|
||||||
infix operator +++
|
infix operator +++
|
||||||
|
|
||||||
@@ -11,9 +14,13 @@ struct Branch : Twig {
|
|||||||
static func doIt(_: Branch, _: Branch) {}
|
static func doIt(_: Branch, _: Branch) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
|
// SILGEN-LABEL: sil hidden [ossa] @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
|
||||||
// CHECK: function_ref @$s18protocol_operators6BranchV4doItyyAC_ACtFZ : $@convention(method) (Branch, Branch, @thin Branch.Type) -> ()
|
// SILGEN: witness_method $Branch, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> ()
|
||||||
// CHECK: return
|
// SILGEN: return
|
||||||
|
|
||||||
|
// OPTIMIZED-LABEL: sil hidden @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
|
||||||
|
// OPTIMIZED: function_ref @$s18protocol_operators6BranchV4doItyyAC_ACtFZ
|
||||||
|
// OPTIMIZED: return
|
||||||
func useBranch(_ b: Branch) {
|
func useBranch(_ b: Branch) {
|
||||||
b +++ b
|
b +++ b
|
||||||
}
|
}
|
||||||
@@ -28,11 +35,17 @@ class Stuck : Stick, ExpressibleByIntegerLiteral {
|
|||||||
required init(integerLiteral: Int) {}
|
required init(integerLiteral: Int) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
|
// SILGEN-LABEL: sil hidden [ossa] @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
|
||||||
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
// SILGEN: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||||
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
// SILGEN: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||||
// CHECK: witness_method $Stuck, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
// SILGEN: witness_method $Stuck, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
// CHECK: return
|
// SILGEN: return
|
||||||
|
|
||||||
|
// OPTIMIZED-LABEL: sil hidden @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
|
||||||
|
// OPTIMIZED: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||||
|
// OPTIMIZED: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||||
|
// OPTIMIZED: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||||
|
// OPTIMIZED: return
|
||||||
func useStick(_ a: Stuck, _ b: Stick) {
|
func useStick(_ a: Stuck, _ b: Stick) {
|
||||||
_ = a +++ b
|
_ = a +++ b
|
||||||
_ = b +++ b
|
_ = b +++ b
|
||||||
@@ -49,10 +62,15 @@ class Rope : Twine<Int>, ExpressibleByIntegerLiteral {
|
|||||||
required init(integerLiteral: Int) {}
|
required init(integerLiteral: Int) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
|
// SILGEN-LABEL: sil hidden [ossa] @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
|
||||||
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
// SILGEN: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||||
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
// SILGEN: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||||
// CHECK: witness_method $Rope, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
// SILGEN: witness_method $Rope, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
|
||||||
|
// OPTIMIZED-LABEL: sil hidden @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
|
||||||
|
// OPTIMIZED: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||||
|
// OPTIMIZED: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||||
|
// OPTIMIZED: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||||
func useRope(_ r: Rope, _ s: Rope) {
|
func useRope(_ r: Rope, _ s: Rope) {
|
||||||
_ = r +++ s
|
_ = r +++ s
|
||||||
_ = s +++ s
|
_ = s +++ s
|
||||||
|
|||||||
Reference in New Issue
Block a user