mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Merge pull request #84800 from xedin/remove-csapply-operator-devirt
[CSApply] Don't attempt operator devirtualization
This commit is contained in:
@@ -43,6 +43,7 @@
|
||||
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
|
||||
#include "swift/SILOptimizer/PassManager/Passes.h"
|
||||
#include "swift/SILOptimizer/PassManager/Transforms.h"
|
||||
#include "swift/SILOptimizer/Utils/Devirtualize.h"
|
||||
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
|
||||
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
|
||||
#include "llvm/ADT/APSInt.h"
|
||||
@@ -100,6 +101,24 @@ private:
|
||||
SILBuilder &builder, SILLocation loc,
|
||||
DifferentiationInvoker invoker);
|
||||
|
||||
/// Creates and canonicalizes private differentiability witness for
|
||||
/// `originalFn`, differentiating with respect to
|
||||
/// `desiredParameterIndices`. Returns `nullptr` on failure signifying that a
|
||||
/// diagnostic has been emitted using `invoker`.
|
||||
SILDifferentiabilityWitness *createPrivateDifferentiabilityWitness(
|
||||
CanSILFunctionType originalFnTy, SILFunction *originalFn,
|
||||
IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
|
||||
SILValue original, DifferentiationInvoker invoker);
|
||||
|
||||
/// Resolves differentiability witness for `originalFn`. Either by looking up
|
||||
/// for registered one with respect to posible superset of `desiredIndices`,
|
||||
/// or creating a new private one. Returns `nullptr` on failure signifying
|
||||
/// that a diagnostic has been emitted using `invoker`.
|
||||
SILDifferentiabilityWitness *getOrCreateMinimalDifferentiabilitywitness(
|
||||
CanSILFunctionType originalFnTy, SILFunction *originalFn,
|
||||
IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
|
||||
SILValue original, DifferentiationInvoker invoker);
|
||||
|
||||
/// Emits a reference to a derivative function of `original`, differentiated
|
||||
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
|
||||
/// the derivative function and the actual indices that the derivative
|
||||
@@ -527,12 +546,142 @@ SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction(
|
||||
return silFnIt->second;
|
||||
}
|
||||
|
||||
SILDifferentiabilityWitness *
|
||||
DifferentiationTransformer::createPrivateDifferentiabilityWitness(
|
||||
CanSILFunctionType originalFnTy, SILFunction *originalFn,
|
||||
IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
|
||||
SILValue original, DifferentiationInvoker invoker) {
|
||||
|
||||
// Check non-differentiable cases before creating a new private
|
||||
// differentiability witness.
|
||||
|
||||
// If the function is intentionally marked as being opaque to
|
||||
// differentiation, then we should not create a task for it.
|
||||
if (originalFn->hasSemanticsAttr("autodiff.opaque")) {
|
||||
context.emitNondifferentiabilityError(
|
||||
original, invoker, diag::autodiff_opaque_function_not_differentiable);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Check and diagnose non-differentiable arguments.
|
||||
for (auto [paramIndex, param] :
|
||||
llvm::enumerate(originalFnTy->getParameters())) {
|
||||
if (!desiredParameterIndices->contains(paramIndex))
|
||||
continue;
|
||||
|
||||
SILType paramType = param.getSILStorageInterfaceType();
|
||||
if (!paramType.isDifferentiable(context.getModule())) {
|
||||
context.emitNondifferentiabilityError(
|
||||
original, invoker, diag::autodiff_nondifferentiable_argument);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Check and diagnose non-differentiable results.
|
||||
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
|
||||
unsigned firstYieldResultIndex =
|
||||
firstSemanticParamResultIdx +
|
||||
originalFnTy->getNumAutoDiffSemanticResultsParameters();
|
||||
for (auto resultIndex : desiredResultIndices->getIndices()) {
|
||||
SILType resultType;
|
||||
if (resultIndex >= firstYieldResultIndex) {
|
||||
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
|
||||
auto yield = originalFnTy->getYields()[yieldResultIndex];
|
||||
// We can only differentiate indirect yields. This should be diagnosed
|
||||
// earlier in VJPCloner.
|
||||
assert(yield.isAutoDiffSemanticResult() && "unsupported result");
|
||||
resultType = yield.getSILStorageInterfaceType();
|
||||
} else if (resultIndex >= firstSemanticParamResultIdx) {
|
||||
auto semanticResultParamIdx = resultIndex - firstSemanticParamResultIdx;
|
||||
auto semanticResultParam = *std::next(
|
||||
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
|
||||
semanticResultParamIdx);
|
||||
resultType = semanticResultParam.getSILStorageInterfaceType();
|
||||
} else {
|
||||
resultType =
|
||||
originalFnTy->getResults()[resultIndex].getSILStorageInterfaceType();
|
||||
}
|
||||
|
||||
if (!resultType || !resultType.isDifferentiable(context.getModule())) {
|
||||
context.emitNondifferentiabilityError(
|
||||
original, invoker, diag::autodiff_nondifferentiable_result);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Check and diagnose external declarations.
|
||||
if (originalFn->isExternalDeclaration()) {
|
||||
context.emitNondifferentiabilityError(
|
||||
original, invoker, diag::autodiff_external_nondifferentiable_function);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Soundness check passed. Create a new differentiability witness
|
||||
GenericSignature contextualDerivativeGenSig = GenericSignature();
|
||||
if (invoker.getKind() ==
|
||||
DifferentiationInvoker::Kind::IndirectDifferentiation)
|
||||
contextualDerivativeGenSig = invoker.getIndirectDifferentiation()
|
||||
.second->getDerivativeGenericSignature();
|
||||
auto derivativeConstrainedGenSig =
|
||||
autodiff::getConstrainedDerivativeGenericSignature(
|
||||
originalFnTy, desiredParameterIndices, desiredResultIndices,
|
||||
contextualDerivativeGenSig, LookUpConformanceInModule());
|
||||
|
||||
auto *witness = SILDifferentiabilityWitness::createDefinition(
|
||||
context.getModule(), SILLinkage::Private, originalFn,
|
||||
DifferentiabilityKind::Reverse, desiredParameterIndices,
|
||||
desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
|
||||
/*vjp*/ nullptr, /*isSerialized*/ false);
|
||||
if (canonicalizeDifferentiabilityWitness(witness, invoker, IsNotSerialized))
|
||||
return nullptr;
|
||||
|
||||
return witness;
|
||||
}
|
||||
|
||||
SILDifferentiabilityWitness *
|
||||
DifferentiationTransformer::getOrCreateMinimalDifferentiabilitywitness(
|
||||
CanSILFunctionType originalFnTy, SILFunction *originalFn,
|
||||
IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
|
||||
SILValue original, DifferentiationInvoker invoker) {
|
||||
// NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has
|
||||
// parameters corresponding to captured variables.
|
||||
// TODO: If possible, change `autodiff::getLoweredParameterIndices` to
|
||||
// take `CaptureInfo` into account.
|
||||
if (originalFnTy->getNumParameters() >
|
||||
desiredParameterIndices->getCapacity()) {
|
||||
desiredParameterIndices = desiredParameterIndices->extendingCapacity(
|
||||
context.getASTContext(), originalFnTy->getNumParameters());
|
||||
}
|
||||
|
||||
// Look up a differentiability witness with the exact configuration.
|
||||
auto *minimalWitness = getExactDifferentiabilityWitness(
|
||||
context.getModule(), originalFn, desiredParameterIndices,
|
||||
desiredResultIndices);
|
||||
|
||||
// Otherwise, look up a differentiability witness with a minimal superset
|
||||
// configuration.
|
||||
if (!minimalWitness)
|
||||
minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness(
|
||||
context.getModule(), originalFn, DifferentiabilityKind::Reverse,
|
||||
desiredParameterIndices, desiredResultIndices);
|
||||
|
||||
// Finally, try to create private differentiability witness
|
||||
if (!minimalWitness)
|
||||
minimalWitness = createPrivateDifferentiabilityWitness(
|
||||
originalFnTy, originalFn, desiredParameterIndices, desiredResultIndices,
|
||||
original, invoker);
|
||||
|
||||
return minimalWitness;
|
||||
}
|
||||
|
||||
std::optional<std::pair<SILValue, AutoDiffConfig>>
|
||||
DifferentiationTransformer::emitDerivativeFunctionReference(
|
||||
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
|
||||
AutoDiffDerivativeFunctionKind kind, SILValue original,
|
||||
DifferentiationInvoker invoker,
|
||||
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
|
||||
SILFunction *parentFn = original->getFunction();
|
||||
|
||||
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
|
||||
// matches the given kind and desired differentiation parameter indices,
|
||||
// simply extract the derivative function of its function operand, retain the
|
||||
@@ -576,111 +725,16 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
|
||||
auto loc = originalFRI->getLoc();
|
||||
auto *originalFn = originalFRI->getReferencedFunction();
|
||||
auto originalFnTy = originalFn->getLoweredFunctionType();
|
||||
auto *desiredParameterIndices = desiredConfig.parameterIndices;
|
||||
auto *desiredResultIndices = desiredConfig.resultIndices;
|
||||
// NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has
|
||||
// parameters corresponding to captured variables.
|
||||
// TODO: If possible, change `autodiff::getLoweredParameterIndices` to
|
||||
// take `CaptureInfo` into account.
|
||||
if (originalFnTy->getNumParameters() >
|
||||
desiredParameterIndices->getCapacity()) {
|
||||
desiredParameterIndices = desiredParameterIndices->extendingCapacity(
|
||||
context.getASTContext(), originalFnTy->getNumParameters());
|
||||
}
|
||||
// Look up a differentiability witness with the exact configuration.
|
||||
auto *minimalWitness = getExactDifferentiabilityWitness(
|
||||
context.getModule(), originalFn, desiredParameterIndices,
|
||||
desiredResultIndices);
|
||||
// Otherwise, look up a differentiability witness with a minimal superset
|
||||
// configuration.
|
||||
|
||||
auto *minimalWitness = getOrCreateMinimalDifferentiabilitywitness(
|
||||
originalFnTy, originalFn, desiredConfig.parameterIndices,
|
||||
desiredConfig.resultIndices, original, invoker);
|
||||
|
||||
// All non-differentiability cases should be diagnosted
|
||||
if (!minimalWitness)
|
||||
minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness(
|
||||
context.getModule(), originalFn, DifferentiabilityKind::Reverse,
|
||||
desiredParameterIndices, desiredResultIndices);
|
||||
// If no minimal witness exists, check non-differentiable cases before
|
||||
// creating a new private differentiability witness.
|
||||
if (!minimalWitness) {
|
||||
// If the function is intentionally marked as being opaque to
|
||||
// differentiation, then we should not create a task for it.
|
||||
if (originalFn->hasSemanticsAttr("autodiff.opaque")) {
|
||||
context.emitNondifferentiabilityError(
|
||||
original, invoker,
|
||||
diag::autodiff_opaque_function_not_differentiable);
|
||||
return std::nullopt;
|
||||
}
|
||||
// Check and diagnose non-differentiable arguments.
|
||||
auto originalFnTy = originalFn->getLoweredFunctionType();
|
||||
for (unsigned paramIndex : range(originalFnTy->getNumParameters())) {
|
||||
if (desiredConfig.isWrtParameter(paramIndex) &&
|
||||
!originalFnTy->getParameters()[paramIndex]
|
||||
.getSILStorageInterfaceType()
|
||||
.isDifferentiable(context.getModule())) {
|
||||
auto diag = context.emitNondifferentiabilityError(
|
||||
original, invoker, diag::autodiff_nondifferentiable_argument);
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
// Check and diagnose non-differentiable results.
|
||||
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
|
||||
unsigned firstYieldResultIndex = originalFnTy->getNumResults() +
|
||||
originalFnTy->getNumAutoDiffSemanticResultsParameters();
|
||||
for (auto resultIndex : desiredResultIndices->getIndices()) {
|
||||
SILType resultType;
|
||||
if (resultIndex >= firstYieldResultIndex) {
|
||||
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
|
||||
auto yield = originalFnTy->getYields()[yieldResultIndex];
|
||||
// We can only differentiate indirect yields. This should be diagnosed
|
||||
// earlier in VJPCloner.
|
||||
assert(yield.isAutoDiffSemanticResult() && "unsupported result");
|
||||
resultType = yield.getSILStorageInterfaceType();
|
||||
} else if (resultIndex >= firstSemanticParamResultIdx) {
|
||||
auto semanticResultParamIdx = resultIndex - firstSemanticParamResultIdx;
|
||||
auto semanticResultParam =
|
||||
*std::next(originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
|
||||
semanticResultParamIdx);
|
||||
resultType = semanticResultParam.getSILStorageInterfaceType();
|
||||
} else {
|
||||
resultType = originalFnTy->getResults()[resultIndex]
|
||||
.getSILStorageInterfaceType();
|
||||
}
|
||||
if (!resultType || !resultType.isDifferentiable(context.getModule())) {
|
||||
context.emitNondifferentiabilityError(
|
||||
original, invoker, diag::autodiff_nondifferentiable_result);
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
// Check and diagnose external declarations.
|
||||
if (originalFn->isExternalDeclaration()) {
|
||||
context.emitNondifferentiabilityError(
|
||||
original, invoker,
|
||||
diag::autodiff_external_nondifferentiable_function);
|
||||
return std::nullopt;
|
||||
}
|
||||
// Soundness check passed. Create a new differentiability witness and
|
||||
// canonicalize it.
|
||||
GenericSignature contextualDerivativeGenSig = GenericSignature();
|
||||
if (invoker.getKind() ==
|
||||
DifferentiationInvoker::Kind::IndirectDifferentiation)
|
||||
contextualDerivativeGenSig =
|
||||
invoker.getIndirectDifferentiation()
|
||||
.second->getDerivativeGenericSignature();
|
||||
auto derivativeConstrainedGenSig =
|
||||
autodiff::getConstrainedDerivativeGenericSignature(
|
||||
originalFn->getLoweredFunctionType(),
|
||||
desiredParameterIndices, desiredResultIndices,
|
||||
contextualDerivativeGenSig,
|
||||
LookUpConformanceInModule());
|
||||
minimalWitness = SILDifferentiabilityWitness::createDefinition(
|
||||
context.getModule(), SILLinkage::Private, originalFn,
|
||||
DifferentiabilityKind::Reverse, desiredParameterIndices,
|
||||
desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
|
||||
/*vjp*/ nullptr, /*isSerialized*/ false);
|
||||
if (canonicalizeDifferentiabilityWitness(minimalWitness, invoker,
|
||||
IsNotSerialized))
|
||||
return std::nullopt;
|
||||
}
|
||||
assert(minimalWitness);
|
||||
if (original->getFunction()->isSerialized()) {
|
||||
return std::nullopt;
|
||||
|
||||
if (parentFn->isSerialized()) {
|
||||
bool isWitnessPublic;
|
||||
if (SILFunction *unwrappedFn =
|
||||
getUnwrappedCurryThunkFunction(originalFn)) {
|
||||
@@ -691,15 +745,12 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
|
||||
// private linkage.
|
||||
isWitnessPublic = hasPublicVisibility(unwrappedFn->getLinkage());
|
||||
} else if (originalFn->getDeclRef().getAbstractClosureExpr() &&
|
||||
originalFRI->getFunction()
|
||||
->getDeclRef()
|
||||
.isDefaultArgGenerator()) {
|
||||
parentFn->getDeclRef().isDefaultArgGenerator()) {
|
||||
// If we reference a closure from inside default argument generator,
|
||||
// check against generator's visibility. If the function having this
|
||||
// default argument has public visibility, it's OK to have a closure
|
||||
// (which always has private visibility) as its default value.
|
||||
isWitnessPublic =
|
||||
hasPublicVisibility(originalFRI->getFunction()->getLinkage());
|
||||
isWitnessPublic = hasPublicVisibility(parentFn->getLinkage());
|
||||
} else {
|
||||
isWitnessPublic = hasPublicVisibility(minimalWitness->getLinkage());
|
||||
}
|
||||
@@ -709,16 +760,16 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
|
||||
// FIXME: This is not a very robust way of determining if the function
|
||||
// is a default argument. Also, we have not exhaustively listed all the
|
||||
// kinds of fragility.
|
||||
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
|
||||
if (parentFn->getLinkage() == SILLinkage::PublicNonABI)
|
||||
fragileKind = DefaultArgument;
|
||||
context.emitNondifferentiabilityError(
|
||||
original, invoker, diag::autodiff_private_derivative_from_fragile,
|
||||
fragileKind,
|
||||
isa_and_nonnull<AbstractClosureExpr>(
|
||||
originalFRI->getLoc().getAsASTNode<Expr>()));
|
||||
isa_and_nonnull<AbstractClosureExpr>(loc.getAsASTNode<Expr>()));
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(TF-482): Move generic requirement checking logic to
|
||||
// `getExactDifferentiabilityWitness` and
|
||||
// `getOrCreateMinimalASTDifferentiabilityWitness`.
|
||||
@@ -726,18 +777,17 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
|
||||
// By default, use the forwarding substitution map of the original function.
|
||||
// If the original callee is a `partial_apply` or `apply` instruction, use
|
||||
// its substitution map instead.
|
||||
auto substMap = original->getFunction()->getForwardingSubstitutionMap();
|
||||
if (auto *pai =
|
||||
peerThroughFunctionConversions<PartialApplyInst>(original)) {
|
||||
auto substMap = parentFn->getForwardingSubstitutionMap();
|
||||
if (auto *pai = peerThroughFunctionConversions<PartialApplyInst>(original))
|
||||
substMap = pai->getSubstitutionMap();
|
||||
} else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original)) {
|
||||
else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original))
|
||||
substMap = ai->getSubstitutionMap();
|
||||
}
|
||||
if (diagnoseUnsatisfiedRequirements(
|
||||
context, original->getType().castTo<SILFunctionType>(),
|
||||
minimalWitness->getDerivativeGenericSignature(), substMap, invoker,
|
||||
original.getLoc().getSourceLoc()))
|
||||
return std::nullopt;
|
||||
|
||||
DifferentiabilityWitnessFunctionKind witnessKind;
|
||||
switch (kind) {
|
||||
case AutoDiffDerivativeFunctionKind::JVP:
|
||||
@@ -764,6 +814,79 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
|
||||
if (auto *witnessMethod =
|
||||
peerThroughFunctionConversions<WitnessMethodInst>(original)) {
|
||||
auto loc = witnessMethod->getLoc();
|
||||
// See if we can derive derivatives for a method statically
|
||||
auto [method, table] = lookUpFunctionInWitnessTable(
|
||||
witnessMethod, SILModule::LinkingMode::LinkNormal);
|
||||
if (method) {
|
||||
auto *originalFn = method;
|
||||
auto originalFnTy = method->getLoweredFunctionTypeInContext(
|
||||
TypeExpansionContext(*original->getFunction()));
|
||||
|
||||
auto *minimalWitness = getOrCreateMinimalDifferentiabilitywitness(
|
||||
originalFnTy, originalFn, desiredConfig.parameterIndices,
|
||||
desiredConfig.resultIndices, original, invoker);
|
||||
|
||||
// All non-differentiability cases should be diagnosted
|
||||
if (!minimalWitness)
|
||||
return std::nullopt;
|
||||
|
||||
if (parentFn->isSerialized() &&
|
||||
!hasPublicVisibility(minimalWitness->getLinkage())) {
|
||||
enum { Inlinable = 0, DefaultArgument = 1 };
|
||||
unsigned fragileKind = Inlinable;
|
||||
// FIXME: This is not a very robust way of determining if the function
|
||||
// is a default argument. Also, we have not exhaustively listed all the
|
||||
// kinds of fragility.
|
||||
if (parentFn->getLinkage() == SILLinkage::PublicNonABI)
|
||||
fragileKind = DefaultArgument;
|
||||
context.emitNondifferentiabilityError(
|
||||
original, invoker, diag::autodiff_private_derivative_from_fragile,
|
||||
fragileKind,
|
||||
isa_and_nonnull<AbstractClosureExpr>(loc.getAsASTNode<Expr>()));
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto substMap = parentFn->getForwardingSubstitutionMap();
|
||||
if (auto *pai =
|
||||
peerThroughFunctionConversions<PartialApplyInst>(original))
|
||||
substMap =
|
||||
getWitnessMethodSubstitutions(context.getModule(), pai, originalFn,
|
||||
witnessMethod->getConformance());
|
||||
else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original))
|
||||
substMap =
|
||||
getWitnessMethodSubstitutions(context.getModule(), ai, originalFn,
|
||||
witnessMethod->getConformance());
|
||||
if (diagnoseUnsatisfiedRequirements(
|
||||
context, original->getType().castTo<SILFunctionType>(),
|
||||
minimalWitness->getDerivativeGenericSignature(), substMap,
|
||||
invoker, original.getLoc().getSourceLoc()))
|
||||
return std::nullopt;
|
||||
|
||||
DifferentiabilityWitnessFunctionKind witnessKind;
|
||||
switch (kind) {
|
||||
case AutoDiffDerivativeFunctionKind::JVP:
|
||||
witnessKind = DifferentiabilityWitnessFunctionKind::JVP;
|
||||
break;
|
||||
case AutoDiffDerivativeFunctionKind::VJP:
|
||||
witnessKind = DifferentiabilityWitnessFunctionKind::VJP;
|
||||
break;
|
||||
}
|
||||
|
||||
auto *derivativeFnRef = builder.createDifferentiabilityWitnessFunction(
|
||||
loc, witnessKind, minimalWitness);
|
||||
|
||||
auto convertedRef = reapplyFunctionConversion(
|
||||
context, derivativeFnRef, witnessMethod, original, builder, loc,
|
||||
newBuffersToDealloc, desiredConfig.parameterIndices,
|
||||
desiredConfig.resultIndices,
|
||||
derivativeFnRef->getType()
|
||||
.getASTType()
|
||||
->castTo<SILFunctionType>()
|
||||
->getSubstGenericSignature());
|
||||
|
||||
return std::make_pair(convertedRef, minimalWitness->getConfig());
|
||||
}
|
||||
|
||||
auto requirementDeclRef = witnessMethod->getMember();
|
||||
auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl();
|
||||
// If requirement declaration does not have any derivative function
|
||||
|
||||
@@ -588,66 +588,6 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns None if the AST does not contain enough information to recover
|
||||
// substitutions; this is different from an Optional(SubstitutionMap()),
|
||||
// indicating a valid call to a non-generic operator.
|
||||
std::optional<SubstitutionMap> getOperatorSubstitutions(ValueDecl *witness,
|
||||
Type refType) {
|
||||
// We have to recover substitutions in this hacky way because
|
||||
// the AST does not retain enough information to devirtualize
|
||||
// calls like this.
|
||||
auto witnessType = witness->getInterfaceType();
|
||||
|
||||
// Compute the substitutions.
|
||||
auto *gft = witnessType->getAs<GenericFunctionType>();
|
||||
if (gft == nullptr) {
|
||||
if (refType->isEqual(witnessType))
|
||||
return SubstitutionMap();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto sig = gft->getGenericSignature();
|
||||
auto *env = sig.getGenericEnvironment();
|
||||
|
||||
witnessType = FunctionType::get(gft->getParams(),
|
||||
gft->getResult(),
|
||||
gft->getExtInfo());
|
||||
witnessType = env->mapTypeIntoContext(witnessType);
|
||||
|
||||
TypeSubstitutionMap subs;
|
||||
auto substType = witnessType->substituteBindingsTo(
|
||||
refType,
|
||||
[&](ArchetypeType *origType, CanType substType) -> CanType {
|
||||
if (auto gpType = dyn_cast<GenericTypeParamType>(
|
||||
origType->getInterfaceType()->getCanonicalType()))
|
||||
subs[gpType] = substType;
|
||||
|
||||
return substType;
|
||||
});
|
||||
|
||||
// If substitution failed, it means that the protocol requirement type
|
||||
// and the witness type did not match up. The only time that this
|
||||
// should happen is when the witness is defined in a base class and
|
||||
// the actual call uses a derived class. For example,
|
||||
//
|
||||
// protocol P { func +(lhs: Self, rhs: Self) }
|
||||
// class Base : P { func +(lhs: Base, rhs: Base) {} }
|
||||
// class Derived : Base {}
|
||||
//
|
||||
// If we enter this code path with two operands of type Derived,
|
||||
// we know we're calling the protocol requirement P.+, with a
|
||||
// substituted type of (Derived, Derived) -> (). But the type of
|
||||
// the witness is (Base, Base) -> (). Just bail out and make a
|
||||
// witness method call in this rare case; SIL mandatory optimizations
|
||||
// will likely devirtualize it anyway.
|
||||
if (!substType)
|
||||
return std::nullopt;
|
||||
|
||||
return SubstitutionMap::get(sig,
|
||||
QueryTypeSubstitutionMap{subs},
|
||||
LookUpConformanceInModule());
|
||||
}
|
||||
|
||||
/// Determine whether the given reference is to a method on
|
||||
/// a remote distributed actor in the given context.
|
||||
bool isDistributedThunk(ConcreteDeclRef ref, Expr *context);
|
||||
@@ -674,65 +614,6 @@ namespace {
|
||||
|
||||
auto baseTy = getBaseType(adjustedFullType->castTo<FunctionType>());
|
||||
|
||||
// Handle operator requirements found in protocols.
|
||||
if (auto proto = dyn_cast<ProtocolDecl>(decl->getDeclContext())) {
|
||||
bool isCurried = shouldBuildCurryThunk(choice, /*baseIsInstance=*/false);
|
||||
|
||||
// If we have a concrete conformance, build a call to the witness.
|
||||
//
|
||||
// FIXME: This is awful. We should be able to handle this as a call to
|
||||
// the protocol requirement with Self == the concrete type, and SILGen
|
||||
// (or later) can devirtualize as appropriate.
|
||||
auto conformance = checkConformance(baseTy, proto);
|
||||
if (conformance.isConcrete()) {
|
||||
if (auto witness = conformance.getConcrete()->getWitnessDecl(decl)) {
|
||||
bool isMemberOperator = witness->getDeclContext()->isTypeContext();
|
||||
|
||||
if (!isMemberOperator || !isCurried) {
|
||||
// The fullType was computed by substituting the protocol
|
||||
// requirement so it always has a (Self) -> ... curried
|
||||
// application. Strip it off if the witness was a top-level
|
||||
// function.
|
||||
Type refType;
|
||||
if (isMemberOperator)
|
||||
refType = adjustedFullType;
|
||||
else
|
||||
refType = adjustedFullType->castTo<AnyFunctionType>()->getResult();
|
||||
|
||||
// Build the AST for the call to the witness.
|
||||
auto subMap = getOperatorSubstitutions(witness, refType);
|
||||
if (subMap) {
|
||||
ConcreteDeclRef witnessRef(witness, *subMap);
|
||||
auto declRefExpr = new (ctx) DeclRefExpr(witnessRef, loc,
|
||||
/*Implicit=*/false);
|
||||
declRefExpr->setFunctionRefInfo(choice.getFunctionRefInfo());
|
||||
cs.setType(declRefExpr, refType);
|
||||
|
||||
Expr *refExpr;
|
||||
if (isMemberOperator) {
|
||||
// If the operator is a type member, add the implicit
|
||||
// (Self) -> ... call.
|
||||
Expr *base =
|
||||
TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy,
|
||||
ctx);
|
||||
cs.setType(base, MetatypeType::get(baseTy));
|
||||
|
||||
refExpr =
|
||||
DotSyntaxCallExpr::create(ctx, declRefExpr, SourceLoc(),
|
||||
Argument::unlabeled(base));
|
||||
auto refType = adjustedFullType->castTo<FunctionType>()->getResult();
|
||||
cs.setType(refExpr, refType);
|
||||
} else {
|
||||
refExpr = declRefExpr;
|
||||
}
|
||||
|
||||
return forceUnwrapIfExpected(refExpr, locator);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build a reference to the member.
|
||||
Expr *base =
|
||||
TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy, ctx);
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
// REQUIRES: executable_test
|
||||
|
||||
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.
|
||||
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||
We cannot use `SIMD` :( */
|
||||
// XFAIL: *
|
||||
|
||||
import _Differentiation
|
||||
import StdlibUnittest
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
// NOTE: Verify whether forward-mode differentiation crashes. It currently does.
|
||||
// RUN: not --crash %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil %s
|
||||
// REQUIRES: executable_test
|
||||
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||
We cannot use `Tracked<T>` :( */
|
||||
// XFAIL: *
|
||||
|
||||
import StdlibUnittest
|
||||
import DifferentiationUnittest
|
||||
|
||||
@@ -9,6 +9,144 @@ import DifferentiationUnittest
|
||||
|
||||
var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty")
|
||||
|
||||
struct TangentSpace : AdditiveArithmetic {
|
||||
let x, y: Float
|
||||
}
|
||||
|
||||
extension TangentSpace : Differentiable {
|
||||
typealias TangentVector = TangentSpace
|
||||
}
|
||||
|
||||
struct Space {
|
||||
/// `x` is a computed property with a custom vjp.
|
||||
var x: Float {
|
||||
@differentiable(reverse)
|
||||
get { storedX }
|
||||
set { storedX = newValue }
|
||||
}
|
||||
|
||||
@derivative(of: x)
|
||||
func vjpX() -> (value: Float, pullback: (Float) -> TangentSpace) {
|
||||
return (x, { v in TangentSpace(x: v, y: 0) } )
|
||||
}
|
||||
|
||||
private var storedX: Float
|
||||
|
||||
@differentiable(reverse)
|
||||
var y: Float
|
||||
|
||||
init(x: Float, y: Float) {
|
||||
self.storedX = x
|
||||
self.y = y
|
||||
}
|
||||
}
|
||||
|
||||
extension Space : Differentiable {
|
||||
typealias TangentVector = TangentSpace
|
||||
mutating func move(by offset: TangentSpace) {
|
||||
x.move(by: offset.x)
|
||||
y.move(by: offset.y)
|
||||
}
|
||||
}
|
||||
|
||||
E2EDifferentiablePropertyTests.test("computed property") {
|
||||
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
|
||||
return 2 * point.x
|
||||
}
|
||||
let expectedGrad = TangentSpace(x: 2, y: 0)
|
||||
expectEqual(expectedGrad, actualGrad)
|
||||
}
|
||||
|
||||
E2EDifferentiablePropertyTests.test("stored property") {
|
||||
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
|
||||
return 3 * point.y
|
||||
}
|
||||
let expectedGrad = TangentSpace(x: 0, y: 3)
|
||||
expectEqual(expectedGrad, actualGrad)
|
||||
}
|
||||
|
||||
struct GenericMemberWrapper<T : Differentiable> : Differentiable {
|
||||
// Stored property.
|
||||
@differentiable(reverse)
|
||||
var x: T
|
||||
|
||||
func vjpX() -> (T, (T.TangentVector) -> GenericMemberWrapper.TangentVector) {
|
||||
return (x, { TangentVector(x: $0) })
|
||||
}
|
||||
}
|
||||
|
||||
E2EDifferentiablePropertyTests.test("generic stored property") {
|
||||
let actualGrad = gradient(at: GenericMemberWrapper<Float>(x: 1)) { point in
|
||||
return 2 * point.x
|
||||
}
|
||||
let expectedGrad = GenericMemberWrapper<Float>.TangentVector(x: 2)
|
||||
expectEqual(expectedGrad, actualGrad)
|
||||
}
|
||||
|
||||
struct ProductSpaceSelfTangent : AdditiveArithmetic {
|
||||
let x, y: Float
|
||||
}
|
||||
|
||||
extension ProductSpaceSelfTangent : Differentiable {
|
||||
typealias TangentVector = ProductSpaceSelfTangent
|
||||
}
|
||||
|
||||
E2EDifferentiablePropertyTests.test("fieldwise product space, self tangent") {
|
||||
let actualGrad = gradient(at: ProductSpaceSelfTangent(x: 0, y: 0)) { (point: ProductSpaceSelfTangent) -> Float in
|
||||
return 5 * point.y
|
||||
}
|
||||
let expectedGrad = ProductSpaceSelfTangent(x: 0, y: 5)
|
||||
expectEqual(expectedGrad, actualGrad)
|
||||
}
|
||||
|
||||
struct ProductSpaceOtherTangentTangentSpace : AdditiveArithmetic {
|
||||
let x, y: Float
|
||||
}
|
||||
|
||||
extension ProductSpaceOtherTangentTangentSpace : Differentiable {
|
||||
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
|
||||
}
|
||||
|
||||
struct ProductSpaceOtherTangent {
|
||||
var x, y: Float
|
||||
}
|
||||
|
||||
extension ProductSpaceOtherTangent : Differentiable {
|
||||
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
|
||||
mutating func move(by offset: ProductSpaceOtherTangentTangentSpace) {
|
||||
x.move(by: offset.x)
|
||||
y.move(by: offset.y)
|
||||
}
|
||||
}
|
||||
|
||||
E2EDifferentiablePropertyTests.test("fieldwise product space, other tangent") {
|
||||
let actualGrad = gradient(
|
||||
at: ProductSpaceOtherTangent(x: 0, y: 0)
|
||||
) { (point: ProductSpaceOtherTangent) -> Float in
|
||||
return 7 * point.y
|
||||
}
|
||||
let expectedGrad = ProductSpaceOtherTangentTangentSpace(x: 0, y: 7)
|
||||
expectEqual(expectedGrad, actualGrad)
|
||||
}
|
||||
|
||||
E2EDifferentiablePropertyTests.test("computed property") {
|
||||
struct TF_544 : Differentiable {
|
||||
var value: Float
|
||||
@differentiable(reverse)
|
||||
var computed: Float {
|
||||
get { value }
|
||||
set { value = newValue }
|
||||
}
|
||||
}
|
||||
let actualGrad = gradient(at: TF_544(value: 2.4)) { x in
|
||||
return x.computed * x.computed
|
||||
}
|
||||
let expectedGrad = TF_544.TangentVector(value: 4.8)
|
||||
expectEqual(expectedGrad, actualGrad)
|
||||
}
|
||||
|
||||
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||
We cannot use `Tracked<T>` :(
|
||||
struct TangentSpace : AdditiveArithmetic {
|
||||
let x, y: Tracked<Float>
|
||||
}
|
||||
@@ -144,5 +282,6 @@ E2EDifferentiablePropertyTests.testWithLeakChecking("computed property") {
|
||||
let expectedGrad = TF_544.TangentVector(value: 4.8)
|
||||
expectEqual(expectedGrad, actualGrad)
|
||||
}
|
||||
*/
|
||||
|
||||
runAllTests()
|
||||
|
||||
@@ -8,19 +8,39 @@ var ExistentialTests = TestSuite("Existential")
|
||||
|
||||
protocol A {
|
||||
@differentiable(reverse, wrt: x)
|
||||
func a(_ x: Tracked<Float>) -> Tracked<Float>
|
||||
func a(_ x: Float) -> Float
|
||||
}
|
||||
func b(g: A) -> Tracked<Float> {
|
||||
func b(g: A) -> Float {
|
||||
return gradient(at: 3) { x in g.a(x) }
|
||||
}
|
||||
|
||||
struct B : A {
|
||||
@differentiable(reverse, wrt: x)
|
||||
func a(_ x: Tracked<Float>) -> Tracked<Float> { return x * 5 }
|
||||
func a(_ x: Float) -> Float { return x * 5 }
|
||||
}
|
||||
|
||||
ExistentialTests.testWithLeakChecking("Existential method VJP") {
|
||||
ExistentialTests.test("Existential method VJP-Tracked") {
|
||||
expectEqual(5.0, b(g: B()))
|
||||
}
|
||||
|
||||
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||
We cannot use `Tracked<T>` :(
|
||||
protocol ATracked {
|
||||
@differentiable(reverse, wrt: x)
|
||||
func a(_ x: Tracked<Float>) -> Tracked<Float>
|
||||
}
|
||||
func b(g: ATracked) -> Tracked<Float> {
|
||||
return gradient(at: 3) { x in g.a(x) }
|
||||
}
|
||||
|
||||
struct BTracked : ATracked {
|
||||
@differentiable(reverse, wrt: x)
|
||||
func a(_ x: Tracked<Float>) -> Tracked<Float> { return x * 5 }
|
||||
}
|
||||
|
||||
ExistentialTests.testWithLeakChecking("Existential method VJP-Tracked") {
|
||||
expectEqual(5.0, b(g: BTracked()))
|
||||
}
|
||||
*/
|
||||
|
||||
runAllTests()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
|
||||
// REQUIRES: executable_test
|
||||
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||
We cannot use `SIMD` :( */
|
||||
// XFAIL: *
|
||||
|
||||
import StdlibUnittest
|
||||
import DifferentiationUnittest
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
|
||||
// REQUIRES: executable_test
|
||||
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||
We cannot use `Tracked<T>` :( */
|
||||
// XFAIL: *
|
||||
|
||||
import StdlibUnittest
|
||||
import DifferentiationUnittest
|
||||
|
||||
@@ -9,50 +9,50 @@ var MethodTests = TestSuite("Method")
|
||||
// ==== Tests with generated adjoint ====
|
||||
|
||||
struct Parameter : Equatable {
|
||||
private let storedX: Tracked<Float>
|
||||
private let storedX: Float
|
||||
@differentiable(reverse, wrt: (self))
|
||||
var x: Tracked<Float> {
|
||||
var x: Float {
|
||||
return storedX
|
||||
}
|
||||
|
||||
init(x: Tracked<Float>) {
|
||||
init(x: Float) {
|
||||
storedX = x
|
||||
}
|
||||
|
||||
@derivative(of: x)
|
||||
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Parameter) {
|
||||
func vjpX() -> (value: Float, pullback: (Float) -> Parameter) {
|
||||
return (x, { dx in Parameter(x: dx) } )
|
||||
}
|
||||
|
||||
@derivative(of: x)
|
||||
func jvpX() -> (value: Tracked<Float>, differential: (Parameter) -> Tracked<Float>) {
|
||||
func jvpX() -> (value: Float, differential: (Parameter) -> Float) {
|
||||
return (x, { $0.x })
|
||||
}
|
||||
}
|
||||
|
||||
extension Parameter {
|
||||
func squared() -> Tracked<Float> {
|
||||
func squared() -> Float {
|
||||
return x * x
|
||||
}
|
||||
|
||||
static func squared(p: Parameter) -> Tracked<Float> {
|
||||
static func squared(p: Parameter) -> Float {
|
||||
return p.x * p.x
|
||||
}
|
||||
|
||||
func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
|
||||
func multiplied(with other: Float) -> Float {
|
||||
return x * other
|
||||
}
|
||||
|
||||
static func * (_ a: Parameter, _ b: Parameter) -> Tracked<Float> {
|
||||
static func * (_ a: Parameter, _ b: Parameter) -> Float {
|
||||
return a.x * b.x
|
||||
}
|
||||
}
|
||||
|
||||
extension Parameter : Differentiable, AdditiveArithmetic {
|
||||
typealias TangentVector = Parameter
|
||||
typealias Scalar = Tracked<Float>
|
||||
typealias Scalar = Float
|
||||
typealias Shape = ()
|
||||
init(repeating repeatedValue: Tracked<Float>, shape: ()) {
|
||||
init(repeating repeatedValue: Float, shape: ()) {
|
||||
self.init(x: repeatedValue)
|
||||
}
|
||||
static func + (lhs: Parameter, rhs: Parameter) -> Parameter {
|
||||
@@ -67,37 +67,185 @@ extension Parameter : Differentiable, AdditiveArithmetic {
|
||||
static var zero: Parameter { return Parameter(x: 0) }
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking(
|
||||
MethodTests.test(
|
||||
"instance method with generated adjoint, called from differentiated func"
|
||||
) {
|
||||
func f(_ p: Parameter) -> Tracked<Float> {
|
||||
func f(_ p: Parameter) -> Float {
|
||||
return 100 * p.squared()
|
||||
}
|
||||
expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
|
||||
expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
|
||||
}
|
||||
|
||||
MethodTests.test(
|
||||
"instance method with generated adjoint, differentiated directly"
|
||||
) {
|
||||
// This is our current syntax for taking gradients of instance methods
|
||||
// directly. If/when we develop nicer syntax for this, change this test.
|
||||
func g(p: Parameter) -> Float { p.squared() }
|
||||
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: g))
|
||||
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: g))
|
||||
}
|
||||
|
||||
MethodTests.test("instance method with generated adjoint, wrt only self") {
|
||||
func f(_ p: Parameter) -> Float {
|
||||
return 100 * p.multiplied(with: 200)
|
||||
}
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.test("instance method with generated adjoint, wrt only non-self") {
|
||||
func f(_ other: Float) -> Float {
|
||||
return 100 * Parameter(x: 200).multiplied(with: other)
|
||||
}
|
||||
expectEqual(100 * 200, gradient(at: 1, of: f))
|
||||
expectEqual(100 * 200, gradient(at: 2, of: f))
|
||||
}
|
||||
|
||||
MethodTests.test(
|
||||
"instance method with generated adjoint, wrt self and non-self"
|
||||
) {
|
||||
expectEqual(
|
||||
(Parameter(x: 100), 200), gradient(at: Parameter(x: 200), 100) { $0.multiplied(with: $1) })
|
||||
expectEqual(
|
||||
(Parameter(x: 200), 100), gradient(at: Parameter(x: 100), 200) { $0.multiplied(with: $1) })
|
||||
}
|
||||
|
||||
MethodTests.test(
|
||||
"static method with generated adjoint, called from differentiated func"
|
||||
) {
|
||||
func f(_ p: Parameter) -> Float {
|
||||
return 100 * Parameter.squared(p: p)
|
||||
}
|
||||
expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
|
||||
expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
|
||||
}
|
||||
|
||||
MethodTests.test(
|
||||
"static method with generated adjoint, differentiated directly"
|
||||
) {
|
||||
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: Parameter.squared))
|
||||
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: Parameter.squared))
|
||||
}
|
||||
|
||||
MethodTests.test("static method with generated adjoint, wrt only first param") {
|
||||
func f(_ p: Parameter) -> Float {
|
||||
return 100 * (p * Parameter(x: 200))
|
||||
}
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.test("static method with generated adjoint, wrt only second param") {
|
||||
func f(_ p: Parameter) -> Float {
|
||||
return 100 * (Parameter(x: 200) * p)
|
||||
}
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.test("static method with generated adjoint, wrt all params") {
|
||||
func g(a: Parameter, b: Parameter) -> Float { a * b }
|
||||
expectEqual((Parameter(x: 100), Parameter(x: 200)),
|
||||
gradient(at: Parameter(x: 200), Parameter(x: 100), of: g))
|
||||
expectEqual((Parameter(x: 200), Parameter(x: 100)),
|
||||
gradient(at: Parameter(x: 100), Parameter(x: 200), of: g))
|
||||
}
|
||||
|
||||
|
||||
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||
We cannot use `Tracked<T>` :(
|
||||
struct ParameterTracked : Equatable {
|
||||
private let storedX: Tracked<Float>
|
||||
@differentiable(reverse, wrt: (self))
|
||||
var x: Tracked<Float> {
|
||||
return storedX
|
||||
}
|
||||
|
||||
init(x: Tracked<Float>) {
|
||||
storedX = x
|
||||
}
|
||||
|
||||
@derivative(of: x)
|
||||
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> ParameterTracked) {
|
||||
return (x, { dx in ParameterTracked(x: dx) } )
|
||||
}
|
||||
|
||||
@derivative(of: x)
|
||||
func jvpX() -> (value: Tracked<Float>, differential: (ParameterTracked) -> Tracked<Float>) {
|
||||
return (x, { $0.x })
|
||||
}
|
||||
}
|
||||
|
||||
extension ParameterTracked {
|
||||
func squared() -> Tracked<Float> {
|
||||
return x * x
|
||||
}
|
||||
|
||||
static func squared(p: ParameterTracked) -> Tracked<Float> {
|
||||
return p.x * p.x
|
||||
}
|
||||
|
||||
func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
|
||||
return x * other
|
||||
}
|
||||
|
||||
static func * (_ a: ParameterTracked, _ b: ParameterTracked) -> Tracked<Float> {
|
||||
return a.x * b.x
|
||||
}
|
||||
}
|
||||
|
||||
extension ParameterTracked : Differentiable, AdditiveArithmetic {
|
||||
typealias TangentVector = ParameterTracked
|
||||
typealias Scalar = Tracked<Float>
|
||||
typealias Shape = ()
|
||||
init(repeating repeatedValue: Tracked<Float>, shape: ()) {
|
||||
self.init(x: repeatedValue)
|
||||
}
|
||||
static func + (lhs: ParameterTracked, rhs: ParameterTracked) -> ParameterTracked {
|
||||
return ParameterTracked(x: lhs.x + rhs.x)
|
||||
}
|
||||
static func - (lhs: ParameterTracked, rhs: ParameterTracked) -> ParameterTracked {
|
||||
return ParameterTracked(x: lhs.x - rhs.x)
|
||||
}
|
||||
static func * (lhs: Scalar, rhs: ParameterTracked) -> ParameterTracked {
|
||||
return ParameterTracked(x: lhs * rhs.x)
|
||||
}
|
||||
static var zero: ParameterTracked { return ParameterTracked(x: 0) }
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking(
|
||||
"instance method with generated adjoint, called from differentiated func"
|
||||
) {
|
||||
func f(_ p: ParameterTracked) -> Tracked<Float> {
|
||||
return 100 * p.squared()
|
||||
}
|
||||
expectEqual(ParameterTracked(x: 4 * 100), gradient(at: ParameterTracked(x: 2), of: f))
|
||||
expectEqual(ParameterTracked(x: 40 * 100), gradient(at: ParameterTracked(x: 20), of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking(
|
||||
"instance method with generated adjoint, differentiated directly"
|
||||
) {
|
||||
// This is our current syntax for taking gradients of instance methods
|
||||
// directly. If/when we develop nicer syntax for this, change this test.
|
||||
func g(p: Parameter) -> Tracked<Float> { p.squared() }
|
||||
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: g))
|
||||
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: g))
|
||||
func g(p: ParameterTracked) -> Tracked<Float> { p.squared() }
|
||||
expectEqual(ParameterTracked(x: 4), gradient(at: ParameterTracked(x: 2), of: g))
|
||||
expectEqual(ParameterTracked(x: 40), gradient(at: ParameterTracked(x: 20), of: g))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only self") {
|
||||
func f(_ p: Parameter) -> Tracked<Float> {
|
||||
func f(_ p: ParameterTracked) -> Tracked<Float> {
|
||||
return 100 * p.multiplied(with: 200)
|
||||
}
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
||||
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
|
||||
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only non-self") {
|
||||
func f(_ other: Tracked<Float>) -> Tracked<Float> {
|
||||
return 100 * Parameter(x: 200).multiplied(with: other)
|
||||
return 100 * ParameterTracked(x: 200).multiplied(with: other)
|
||||
}
|
||||
expectEqual(100 * 200, gradient(at: 1, of: f))
|
||||
expectEqual(100 * 200, gradient(at: 2, of: f))
|
||||
@@ -107,51 +255,52 @@ MethodTests.testWithLeakChecking(
|
||||
"instance method with generated adjoint, wrt self and non-self"
|
||||
) {
|
||||
expectEqual(
|
||||
(Parameter(x: 100), 200), gradient(at: Parameter(x: 200), 100) { $0.multiplied(with: $1) })
|
||||
(ParameterTracked(x: 100), 200), gradient(at: ParameterTracked(x: 200), 100) { $0.multiplied(with: $1) })
|
||||
expectEqual(
|
||||
(Parameter(x: 200), 100), gradient(at: Parameter(x: 100), 200) { $0.multiplied(with: $1) })
|
||||
(ParameterTracked(x: 200), 100), gradient(at: ParameterTracked(x: 100), 200) { $0.multiplied(with: $1) })
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking(
|
||||
"static method with generated adjoint, called from differentiated func"
|
||||
) {
|
||||
func f(_ p: Parameter) -> Tracked<Float> {
|
||||
return 100 * Parameter.squared(p: p)
|
||||
func f(_ p: ParameterTracked) -> Tracked<Float> {
|
||||
return 100 * ParameterTracked.squared(p: p)
|
||||
}
|
||||
expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
|
||||
expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
|
||||
expectEqual(ParameterTracked(x: 4 * 100), gradient(at: ParameterTracked(x: 2), of: f))
|
||||
expectEqual(ParameterTracked(x: 40 * 100), gradient(at: ParameterTracked(x: 20), of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking(
|
||||
"static method with generated adjoint, differentiated directly"
|
||||
) {
|
||||
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: Parameter.squared))
|
||||
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: Parameter.squared))
|
||||
expectEqual(ParameterTracked(x: 4), gradient(at: ParameterTracked(x: 2), of: ParameterTracked.squared))
|
||||
expectEqual(ParameterTracked(x: 40), gradient(at: ParameterTracked(x: 20), of: ParameterTracked.squared))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only first param") {
|
||||
func f(_ p: Parameter) -> Tracked<Float> {
|
||||
return 100 * (p * Parameter(x: 200))
|
||||
func f(_ p: ParameterTracked) -> Tracked<Float> {
|
||||
return 100 * (p * ParameterTracked(x: 200))
|
||||
}
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
||||
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
|
||||
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only second param") {
|
||||
func f(_ p: Parameter) -> Tracked<Float> {
|
||||
return 100 * (Parameter(x: 200) * p)
|
||||
func f(_ p: ParameterTracked) -> Tracked<Float> {
|
||||
return 100 * (ParameterTracked(x: 200) * p)
|
||||
}
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
|
||||
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
|
||||
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
|
||||
expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt all params") {
|
||||
func g(a: Parameter, b: Parameter) -> Tracked<Float> { a * b }
|
||||
expectEqual((Parameter(x: 100), Parameter(x: 200)),
|
||||
gradient(at: Parameter(x: 200), Parameter(x: 100), of: g))
|
||||
expectEqual((Parameter(x: 200), Parameter(x: 100)),
|
||||
gradient(at: Parameter(x: 100), Parameter(x: 200), of: g))
|
||||
func g(a: ParameterTracked, b: ParameterTracked) -> Tracked<Float> { a * b }
|
||||
expectEqual((ParameterTracked(x: 100), ParameterTracked(x: 200)),
|
||||
gradient(at: ParameterTracked(x: 200), ParameterTracked(x: 100), of: g))
|
||||
expectEqual((ParameterTracked(x: 200), ParameterTracked(x: 100)),
|
||||
gradient(at: ParameterTracked(x: 100), ParameterTracked(x: 200), of: g))
|
||||
}
|
||||
*/
|
||||
|
||||
// ==== Tests with custom adjoint ====
|
||||
|
||||
@@ -174,27 +323,27 @@ struct DiffWrtSelf : Differentiable {
|
||||
}
|
||||
|
||||
struct CustomParameter : Equatable {
|
||||
let storedX: Tracked<Float>
|
||||
let storedX: Float
|
||||
@differentiable(reverse, wrt: (self))
|
||||
var x: Tracked<Float> {
|
||||
var x: Float {
|
||||
return storedX
|
||||
}
|
||||
|
||||
init(x: Tracked<Float>) {
|
||||
init(x: Float) {
|
||||
storedX = x
|
||||
}
|
||||
|
||||
@derivative(of: x)
|
||||
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) {
|
||||
func vjpX() -> (value: Float, pullback: (Float) -> CustomParameter) {
|
||||
return (x, { dx in CustomParameter(x: dx) })
|
||||
}
|
||||
}
|
||||
|
||||
extension CustomParameter : Differentiable, AdditiveArithmetic {
|
||||
typealias TangentVector = CustomParameter
|
||||
typealias Scalar = Tracked<Float>
|
||||
typealias Scalar = Float
|
||||
typealias Shape = ()
|
||||
init(repeating repeatedValue: Tracked<Float>, shape: ()) {
|
||||
init(repeating repeatedValue: Float, shape: ()) {
|
||||
self.init(x: repeatedValue)
|
||||
}
|
||||
static func + (lhs: CustomParameter, rhs: CustomParameter) -> CustomParameter {
|
||||
@@ -209,38 +358,256 @@ extension CustomParameter : Differentiable, AdditiveArithmetic {
|
||||
static var zero: CustomParameter { return CustomParameter(x: 0) }
|
||||
}
|
||||
|
||||
extension Tracked where T : FloatingPoint {
|
||||
func clamped(to limits: ClosedRange<Tracked<T>>) -> Tracked<T> {
|
||||
extension Float {
|
||||
func clamped(to limits: ClosedRange<Float>) -> Float {
|
||||
return min(max(self, limits.lowerBound), limits.upperBound)
|
||||
}
|
||||
}
|
||||
|
||||
extension CustomParameter {
|
||||
@differentiable(reverse, wrt: (self))
|
||||
func squared() -> Tracked<Float> {
|
||||
func squared() -> Float {
|
||||
return x * x
|
||||
}
|
||||
|
||||
@derivative(of: squared)
|
||||
func dSquared() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) {
|
||||
func dSquared() -> (value: Float, pullback: (Float) -> CustomParameter) {
|
||||
return (squared(), { [x] v in CustomParameter(x: (2 * x).clamped(to: -10.0...10.0) * v) })
|
||||
}
|
||||
|
||||
@differentiable(reverse)
|
||||
static func squared(p: CustomParameter) -> Tracked<Float> {
|
||||
static func squared(p: CustomParameter) -> Float {
|
||||
return p.x * p.x
|
||||
}
|
||||
|
||||
@derivative(of: squared)
|
||||
static func dSquared(
|
||||
_ p: CustomParameter
|
||||
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) {
|
||||
) -> (value: Float, pullback: (Float) -> CustomParameter) {
|
||||
return (p.x * p.x, { v in CustomParameter(x: (2 * p.x).clamped(to: -10.0...10.0) * v) })
|
||||
}
|
||||
|
||||
// There is currently no way to define multiple custom VJPs wrt different
|
||||
// parameters on the same func, so we define a copy of this func per adjoint.
|
||||
|
||||
@differentiable(reverse, wrt: (self, other))
|
||||
func multiplied(with other: Float) -> Float {
|
||||
return x * other
|
||||
}
|
||||
|
||||
@differentiable(reverse, wrt: (other))
|
||||
func multiplied_constSelf(with other: Float) -> Float {
|
||||
return x * other
|
||||
}
|
||||
|
||||
@differentiable(reverse, wrt: (self))
|
||||
func multiplied_constOther(with other: Float) -> Float {
|
||||
return x * other
|
||||
}
|
||||
|
||||
@derivative(of: multiplied)
|
||||
func dMultiplied_wrtAll(
|
||||
with other: Float
|
||||
) -> (value: Float, pullback: (Float) -> (CustomParameter, Float)) {
|
||||
return (multiplied(with: other),
|
||||
{ [x] v in (CustomParameter(x: other.clamped(to: -10.0...10.0) * v),
|
||||
x.clamped(to: -10.0...10.0) * v) })
|
||||
}
|
||||
|
||||
@derivative(of: multiplied_constSelf, wrt: other)
|
||||
func dMultiplied_wrtOther(
|
||||
with other: Float
|
||||
) -> (value: Float, pullback: (Float) -> Float) {
|
||||
let (r, pb) = dMultiplied_wrtAll(with: other)
|
||||
return (r, { v in pb(v).1 })
|
||||
}
|
||||
|
||||
@derivative(of: multiplied_constOther, wrt: self)
|
||||
func dMultiplied_wrtSelf(
|
||||
with other: Float
|
||||
) -> (value: Float, pullback: (Float) -> CustomParameter) {
|
||||
let (r, pb) = dMultiplied_wrtAll(with: other)
|
||||
return (r, { v in pb(v).0 })
|
||||
}
|
||||
|
||||
@differentiable(reverse)
|
||||
static func multiply(_ lhs: CustomParameter, _ rhs: CustomParameter)
|
||||
-> Float {
|
||||
return lhs.x * rhs.x
|
||||
}
|
||||
|
||||
@differentiable(reverse, wrt: (rhs))
|
||||
static func multiply_constLhs(_ lhs: CustomParameter, _ rhs: CustomParameter) -> Float {
|
||||
return lhs.x * rhs.x
|
||||
}
|
||||
|
||||
@derivative(of: multiply)
|
||||
static func dMultiply_wrtAll(_ lhs: CustomParameter,_ rhs: CustomParameter)
|
||||
-> (value: Float, pullback: (Float) -> (CustomParameter, CustomParameter)) {
|
||||
let result = multiply(lhs, rhs)
|
||||
return (result, { v in (CustomParameter(x: rhs.x.clamped(to: -10.0...10.0) * v),
|
||||
CustomParameter(x: lhs.x.clamped(to: -10.0...10.0) * v)) })
|
||||
}
|
||||
|
||||
@derivative(of: multiply_constLhs, wrt: rhs)
|
||||
static func dMultiply_wrtRhs(_ lhs: CustomParameter, _ rhs: CustomParameter)
|
||||
-> (value: Float, pullback: (Float) -> CustomParameter) {
|
||||
let (r, pb) = dMultiply_wrtAll(lhs, rhs)
|
||||
return (r, { v in pb(v).1 })
|
||||
}
|
||||
}
|
||||
|
||||
MethodTests.test(
|
||||
"instance method with custom adjoint, called from differentiated func"
|
||||
) {
|
||||
func f(_ p: CustomParameter) -> Float {
|
||||
return 100 * p.squared()
|
||||
}
|
||||
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
|
||||
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
|
||||
}
|
||||
|
||||
MethodTests.test("instance method with generated adjoint, differentiated directly") {
|
||||
// This is our current syntax for taking gradients of instance methods
|
||||
// directly. If/when we develop nicer syntax for this, change this test.
|
||||
func g(p: CustomParameter) -> Float { p.squared() }
|
||||
expectEqual(CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: g))
|
||||
expectEqual(CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: g))
|
||||
}
|
||||
|
||||
MethodTests.test("static method with custom adjoint, called from differentiated func") {
|
||||
func f(_ p: CustomParameter) -> Float {
|
||||
return 100 * CustomParameter.squared(p: p)
|
||||
}
|
||||
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
|
||||
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
|
||||
}
|
||||
|
||||
MethodTests.test("static method with custom adjoint, differentiated directly") {
|
||||
expectEqual(
|
||||
CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: CustomParameter.squared))
|
||||
expectEqual(
|
||||
CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: CustomParameter.squared))
|
||||
}
|
||||
|
||||
MethodTests.test("instance method with custom adjoint, wrt only self") {
|
||||
func f(_ p: CustomParameter) -> Float {
|
||||
return 100 * p.multiplied_constOther(with: 200)
|
||||
}
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.test("instance method with custom adjoint, wrt only non-self") {
|
||||
func f(_ other: Float) -> Float {
|
||||
return 100 * CustomParameter(x: 200).multiplied_constSelf(with: other)
|
||||
}
|
||||
expectEqual(100 * 10, gradient(at: 1, of: f))
|
||||
expectEqual(100 * 10, gradient(at: 2, of: f))
|
||||
}
|
||||
|
||||
MethodTests.test("instance method with custom adjoint, wrt self and non-self") {
|
||||
func g(p: CustomParameter, o: Float) -> Float { p.multiplied(with: o) }
|
||||
expectEqual((CustomParameter(x: 5), 10), gradient(at: CustomParameter(x: 100), 5, of: g))
|
||||
expectEqual((CustomParameter(x: 10), 5), gradient(at: CustomParameter(x: 5), 100, of: g))
|
||||
}
|
||||
|
||||
MethodTests.test("static method with custom adjoint, wrt only lhs") {
|
||||
func f(_ p: CustomParameter) -> Float {
|
||||
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
|
||||
}
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.test("static method with custom adjoint, wrt only rhs") {
|
||||
func f(_ p: CustomParameter) -> Float {
|
||||
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
|
||||
}
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.test("static method with custom adjoint, wrt all") {
|
||||
func f(_ a: CustomParameter, _ b: CustomParameter) -> Float {
|
||||
return CustomParameter.multiply(a, b)
|
||||
}
|
||||
expectEqual((CustomParameter(x: 5), CustomParameter(x: 10)),
|
||||
gradient(at: CustomParameter(x: 100), CustomParameter(x: 5), of: f))
|
||||
expectEqual((CustomParameter(x: 10), CustomParameter(x: 5)),
|
||||
gradient(at: CustomParameter(x: 5), CustomParameter(x: 100), of: f))
|
||||
}
|
||||
|
||||
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||
We cannot use `Tracked<T>` :(
|
||||
struct CustomParameterTracked : Equatable {
|
||||
let storedX: Tracked<Float>
|
||||
@differentiable(reverse, wrt: (self))
|
||||
var x: Tracked<Float> {
|
||||
return storedX
|
||||
}
|
||||
|
||||
init(x: Tracked<Float>) {
|
||||
storedX = x
|
||||
}
|
||||
|
||||
@derivative(of: x)
|
||||
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
|
||||
return (x, { dx in CustomParameterTracked(x: dx) })
|
||||
}
|
||||
}
|
||||
|
||||
extension CustomParameterTracked : Differentiable, AdditiveArithmetic {
|
||||
typealias TangentVector = CustomParameterTracked
|
||||
typealias Scalar = Tracked<Float>
|
||||
typealias Shape = ()
|
||||
init(repeating repeatedValue: Tracked<Float>, shape: ()) {
|
||||
self.init(x: repeatedValue)
|
||||
}
|
||||
static func + (lhs: CustomParameterTracked, rhs: CustomParameterTracked) -> CustomParameterTracked {
|
||||
return CustomParameterTracked(x: lhs.x + rhs.x)
|
||||
}
|
||||
static func - (lhs: CustomParameterTracked, rhs: CustomParameterTracked) -> CustomParameterTracked {
|
||||
return CustomParameterTracked(x: lhs.x - rhs.x)
|
||||
}
|
||||
static func * (lhs: Scalar, rhs: CustomParameterTracked) -> CustomParameterTracked {
|
||||
return CustomParameterTracked(x: lhs * rhs.x)
|
||||
}
|
||||
static var zero: CustomParameterTracked { return CustomParameterTracked(x: 0) }
|
||||
}
|
||||
|
||||
extension Tracked where T : FloatingPoint {
|
||||
func clamped(to limits: ClosedRange<Tracked<T>>) -> Tracked<T> {
|
||||
return min(max(self, limits.lowerBound), limits.upperBound)
|
||||
}
|
||||
}
|
||||
|
||||
extension CustomParameterTracked {
|
||||
@differentiable(reverse, wrt: (self))
|
||||
func squared() -> Tracked<Float> {
|
||||
return x * x
|
||||
}
|
||||
|
||||
@derivative(of: squared)
|
||||
func dSquared() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
|
||||
return (squared(), { [x] v in CustomParameterTracked(x: (2 * x).clamped(to: -10.0...10.0) * v) })
|
||||
}
|
||||
|
||||
@differentiable(reverse)
|
||||
static func squared(p: CustomParameterTracked) -> Tracked<Float> {
|
||||
return p.x * p.x
|
||||
}
|
||||
|
||||
@derivative(of: squared)
|
||||
static func dSquared(
|
||||
_ p: CustomParameterTracked
|
||||
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
|
||||
return (p.x * p.x, { v in CustomParameterTracked(x: (2 * p.x).clamped(to: -10.0...10.0) * v) })
|
||||
}
|
||||
|
||||
// There is currently no way to define multiple custom VJPs wrt different
|
||||
// parameters on the same func, so we define a copy of this func per adjoint.
|
||||
|
||||
@differentiable(reverse, wrt: (self, other))
|
||||
func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
|
||||
return x * other
|
||||
@@ -259,9 +626,9 @@ extension CustomParameter {
|
||||
@derivative(of: multiplied)
|
||||
func dMultiplied_wrtAll(
|
||||
with other: Tracked<Float>
|
||||
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameter, Tracked<Float>)) {
|
||||
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameterTracked, Tracked<Float>)) {
|
||||
return (multiplied(with: other),
|
||||
{ [x] v in (CustomParameter(x: other.clamped(to: -10.0...10.0) * v),
|
||||
{ [x] v in (CustomParameterTracked(x: other.clamped(to: -10.0...10.0) * v),
|
||||
x.clamped(to: -10.0...10.0) * v) })
|
||||
}
|
||||
|
||||
@@ -276,33 +643,33 @@ extension CustomParameter {
|
||||
@derivative(of: multiplied_constOther, wrt: self)
|
||||
func dMultiplied_wrtSelf(
|
||||
with other: Tracked<Float>
|
||||
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) {
|
||||
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
|
||||
let (r, pb) = dMultiplied_wrtAll(with: other)
|
||||
return (r, { v in pb(v).0 })
|
||||
}
|
||||
|
||||
@differentiable(reverse)
|
||||
static func multiply(_ lhs: CustomParameter, _ rhs: CustomParameter)
|
||||
static func multiply(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked)
|
||||
-> Tracked<Float> {
|
||||
return lhs.x * rhs.x
|
||||
}
|
||||
|
||||
@differentiable(reverse, wrt: (rhs))
|
||||
static func multiply_constLhs(_ lhs: CustomParameter, _ rhs: CustomParameter) -> Tracked<Float> {
|
||||
static func multiply_constLhs(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked) -> Tracked<Float> {
|
||||
return lhs.x * rhs.x
|
||||
}
|
||||
|
||||
@derivative(of: multiply)
|
||||
static func dMultiply_wrtAll(_ lhs: CustomParameter,_ rhs: CustomParameter)
|
||||
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameter, CustomParameter)) {
|
||||
static func dMultiply_wrtAll(_ lhs: CustomParameterTracked,_ rhs: CustomParameterTracked)
|
||||
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameterTracked, CustomParameterTracked)) {
|
||||
let result = multiply(lhs, rhs)
|
||||
return (result, { v in (CustomParameter(x: rhs.x.clamped(to: -10.0...10.0) * v),
|
||||
CustomParameter(x: lhs.x.clamped(to: -10.0...10.0) * v)) })
|
||||
return (result, { v in (CustomParameterTracked(x: rhs.x.clamped(to: -10.0...10.0) * v),
|
||||
CustomParameterTracked(x: lhs.x.clamped(to: -10.0...10.0) * v)) })
|
||||
}
|
||||
|
||||
@derivative(of: multiply_constLhs, wrt: rhs)
|
||||
static func dMultiply_wrtRhs(_ lhs: CustomParameter, _ rhs: CustomParameter)
|
||||
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) {
|
||||
static func dMultiply_wrtRhs(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked)
|
||||
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
|
||||
let (r, pb) = dMultiply_wrtAll(lhs, rhs)
|
||||
return (r, { v in pb(v).1 })
|
||||
}
|
||||
@@ -311,82 +678,83 @@ extension CustomParameter {
|
||||
MethodTests.testWithLeakChecking(
|
||||
"instance method with custom adjoint, called from differentiated func"
|
||||
) {
|
||||
func f(_ p: CustomParameter) -> Tracked<Float> {
|
||||
func f(_ p: CustomParameterTracked) -> Tracked<Float> {
|
||||
return 100 * p.squared()
|
||||
}
|
||||
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
|
||||
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
|
||||
expectEqual(CustomParameterTracked(x: 4 * 100), gradient(at: CustomParameterTracked(x: 2), of: f))
|
||||
expectEqual(CustomParameterTracked(x: 10 * 100), gradient(at: CustomParameterTracked(x: 20), of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("instance method with generated adjoint, differentiated directly") {
|
||||
// This is our current syntax for taking gradients of instance methods
|
||||
// directly. If/when we develop nicer syntax for this, change this test.
|
||||
func g(p: CustomParameter) -> Tracked<Float> { p.squared() }
|
||||
expectEqual(CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: g))
|
||||
expectEqual(CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: g))
|
||||
func g(p: CustomParameterTracked) -> Tracked<Float> { p.squared() }
|
||||
expectEqual(CustomParameterTracked(x: 4), gradient(at: CustomParameterTracked(x: 2), of: g))
|
||||
expectEqual(CustomParameterTracked(x: 10), gradient(at: CustomParameterTracked(x: 20), of: g))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("static method with custom adjoint, called from differentiated func") {
|
||||
func f(_ p: CustomParameter) -> Tracked<Float> {
|
||||
return 100 * CustomParameter.squared(p: p)
|
||||
func f(_ p: CustomParameterTracked) -> Tracked<Float> {
|
||||
return 100 * CustomParameterTracked.squared(p: p)
|
||||
}
|
||||
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
|
||||
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
|
||||
expectEqual(CustomParameterTracked(x: 4 * 100), gradient(at: CustomParameterTracked(x: 2), of: f))
|
||||
expectEqual(CustomParameterTracked(x: 10 * 100), gradient(at: CustomParameterTracked(x: 20), of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("static method with custom adjoint, differentiated directly") {
|
||||
expectEqual(
|
||||
CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: CustomParameter.squared))
|
||||
CustomParameterTracked(x: 4), gradient(at: CustomParameterTracked(x: 2), of: CustomParameterTracked.squared))
|
||||
expectEqual(
|
||||
CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: CustomParameter.squared))
|
||||
CustomParameterTracked(x: 10), gradient(at: CustomParameterTracked(x: 20), of: CustomParameterTracked.squared))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only self") {
|
||||
func f(_ p: CustomParameter) -> Tracked<Float> {
|
||||
func f(_ p: CustomParameterTracked) -> Tracked<Float> {
|
||||
return 100 * p.multiplied_constOther(with: 200)
|
||||
}
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
||||
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
|
||||
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only non-self") {
|
||||
func f(_ other: Tracked<Float>) -> Tracked<Float> {
|
||||
return 100 * CustomParameter(x: 200).multiplied_constSelf(with: other)
|
||||
return 100 * CustomParameterTracked(x: 200).multiplied_constSelf(with: other)
|
||||
}
|
||||
expectEqual(100 * 10, gradient(at: 1, of: f))
|
||||
expectEqual(100 * 10, gradient(at: 2, of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt self and non-self") {
|
||||
func g(p: CustomParameter, o: Tracked<Float>) -> Tracked<Float> { p.multiplied(with: o) }
|
||||
expectEqual((CustomParameter(x: 5), 10), gradient(at: CustomParameter(x: 100), 5, of: g))
|
||||
expectEqual((CustomParameter(x: 10), 5), gradient(at: CustomParameter(x: 5), 100, of: g))
|
||||
func g(p: CustomParameterTracked, o: Tracked<Float>) -> Tracked<Float> { p.multiplied(with: o) }
|
||||
expectEqual((CustomParameterTracked(x: 5), 10), gradient(at: CustomParameterTracked(x: 100), 5, of: g))
|
||||
expectEqual((CustomParameterTracked(x: 10), 5), gradient(at: CustomParameterTracked(x: 5), 100, of: g))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only lhs") {
|
||||
func f(_ p: CustomParameter) -> Tracked<Float> {
|
||||
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
|
||||
func f(_ p: CustomParameterTracked) -> Tracked<Float> {
|
||||
return 100 * CustomParameterTracked.multiply_constLhs(CustomParameterTracked(x: 200), p)
|
||||
}
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
||||
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
|
||||
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only rhs") {
|
||||
func f(_ p: CustomParameter) -> Tracked<Float> {
|
||||
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
|
||||
func f(_ p: CustomParameterTracked) -> Tracked<Float> {
|
||||
return 100 * CustomParameterTracked.multiply_constLhs(CustomParameterTracked(x: 200), p)
|
||||
}
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
|
||||
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
|
||||
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
|
||||
expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
|
||||
}
|
||||
|
||||
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt all") {
|
||||
func f(_ a: CustomParameter, _ b: CustomParameter) -> Tracked<Float> {
|
||||
return CustomParameter.multiply(a, b)
|
||||
func f(_ a: CustomParameterTracked, _ b: CustomParameterTracked) -> Tracked<Float> {
|
||||
return CustomParameterTracked.multiply(a, b)
|
||||
}
|
||||
expectEqual((CustomParameter(x: 5), CustomParameter(x: 10)),
|
||||
gradient(at: CustomParameter(x: 100), CustomParameter(x: 5), of: f))
|
||||
expectEqual((CustomParameter(x: 10), CustomParameter(x: 5)),
|
||||
gradient(at: CustomParameter(x: 5), CustomParameter(x: 100), of: f))
|
||||
expectEqual((CustomParameterTracked(x: 5), CustomParameterTracked(x: 10)),
|
||||
gradient(at: CustomParameterTracked(x: 100), CustomParameterTracked(x: 5), of: f))
|
||||
expectEqual((CustomParameterTracked(x: 10), CustomParameterTracked(x: 5)),
|
||||
gradient(at: CustomParameterTracked(x: 5), CustomParameterTracked(x: 100), of: f))
|
||||
}
|
||||
*/
|
||||
|
||||
runAllTests()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: %target-run-simple-swift
|
||||
1// RUN: %target-run-simple-swift
|
||||
// REQUIRES: executable_test
|
||||
|
||||
import StdlibUnittest
|
||||
@@ -6,7 +6,19 @@ import DifferentiationUnittest
|
||||
|
||||
var RepeatedCallsTests = TestSuite("RepeatedCalls")
|
||||
|
||||
RepeatedCallsTests.testWithLeakChecking("Repeat") {
|
||||
RepeatedCallsTests.test("Repeat") {
|
||||
func mul2(_ x: Float) -> Float {
|
||||
return 2 * x
|
||||
}
|
||||
func mul4(_ x: Float) -> Float {
|
||||
return mul2(mul2(x))
|
||||
}
|
||||
expectEqual(4, gradient(at: 0, of: mul4))
|
||||
}
|
||||
|
||||
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||
We cannot use `Tracked<T>` :(
|
||||
RepeatedCallsTests.testWithLeakChecking("Repeat-Tracked") {
|
||||
func mul2(_ x: Tracked<Float>) -> Tracked<Float> {
|
||||
return 2 * x
|
||||
}
|
||||
@@ -15,5 +27,6 @@ RepeatedCallsTests.testWithLeakChecking("Repeat") {
|
||||
}
|
||||
expectEqual(4, gradient(at: 0, of: mul4))
|
||||
}
|
||||
*/
|
||||
|
||||
runAllTests()
|
||||
|
||||
@@ -266,6 +266,9 @@ SimpleMathTests.test("TupleMutation") {
|
||||
}
|
||||
|
||||
// Tests TF-321.
|
||||
|
||||
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
|
||||
We cannot use `Tracked<T>` :(
|
||||
SimpleMathTests.test("TupleNonDifferentiableElements") {
|
||||
// TF-964: Test tuple with non-tuple-typed adjoint value.
|
||||
func tupleLet(_ x: Tracked<Float>) -> Tracked<Float> {
|
||||
@@ -309,6 +312,51 @@ SimpleMathTests.test("TupleNonDifferentiableElements") {
|
||||
}
|
||||
expectEqual((3, 1), valueWithGradient(at: 3, of: wrapper))
|
||||
}
|
||||
*/
|
||||
|
||||
SimpleMathTests.test("TupleNonDifferentiableElementsNotTracked") {
|
||||
// TF-964: Test tuple with non-tuple-typed adjoint value.
|
||||
func tupleLet(_ x: Float) -> Float {
|
||||
let tuple = (2 * x, 1)
|
||||
return tuple.0
|
||||
}
|
||||
expectEqual((8, 2), valueWithGradient(at: 4, of: tupleLet))
|
||||
|
||||
func tupleVar(_ x: Float) -> Float {
|
||||
var tuple = (x, 1)
|
||||
tuple.0 = x
|
||||
tuple.1 = 1
|
||||
return tuple.0
|
||||
}
|
||||
expectEqual((3, 1), valueWithGradient(at: 3, of: tupleVar))
|
||||
|
||||
func nested(_ x: Float) -> Float {
|
||||
// Convoluted function computing `x * x`.
|
||||
var tuple: (Int, (Int, Float), Float) = (1, (1, 0), 0)
|
||||
tuple.0 = 1
|
||||
tuple.1.0 = 1
|
||||
tuple.1.1 = x
|
||||
tuple.2 = x
|
||||
return tuple.1.1 * tuple.2
|
||||
}
|
||||
expectEqual((16, 8), valueWithGradient(at: 4, of: nested))
|
||||
|
||||
struct Wrapper<T> {
|
||||
@differentiable(reverse where T : Differentiable)
|
||||
func baz(_ x: T) -> T {
|
||||
var tuple = (1, 1, x, 1)
|
||||
tuple.0 = 1
|
||||
tuple.2 = x
|
||||
tuple.3 = 1
|
||||
return tuple.2
|
||||
}
|
||||
}
|
||||
func wrapper(_ x: Float) -> Float {
|
||||
let w = Wrapper<Float>()
|
||||
return w.baz(x)
|
||||
}
|
||||
expectEqual((3, 1), valueWithGradient(at: 3, of: wrapper))
|
||||
}
|
||||
|
||||
// Tests TF-21.
|
||||
SimpleMathTests.test("StructMemberwiseInitializer") {
|
||||
|
||||
@@ -21,9 +21,10 @@ func test(arr: [any P]) {
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure that we don't pick a concrete `(AnyHashable, AnyHashable) -> Bool` overload.
|
||||
|
||||
// CHECK: sil private [ossa] @$s34anyhashable_and_operator_filtering4test3arrySayAA1P_pG_tFyAaD_pXEfU_
|
||||
// CHECK: [[LHS_ARG:%.*]] = alloc_stack $E
|
||||
// CHECK: [[RHS_ARG:%.*]] = alloc_stack $E
|
||||
// CHECK: function_ref == infix<A>(_:_:)
|
||||
// CHECK-NEXT: [[GENERIC_OP:%.*]] = function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF : $@convention(thin) <τ_0_0 where τ_0_0 : RawRepresentable, τ_0_0.RawValue : Equatable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> Bool
|
||||
// CHECK-NEXT: apply [[GENERIC_OP]]<E>([[LHS_ARG]], [[RHS_ARG]])
|
||||
// CHECK: [[GENERIC_OP:%.*]] = witness_method $E, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||
// CHECK-NEXT: apply [[GENERIC_OP]]<E>([[LHS_ARG]], [[RHS_ARG]], {{.*}})
|
||||
|
||||
@@ -8,9 +8,8 @@ struct Value: Equatable, ExpressibleByNilLiteral {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: sil hidden [ossa] @$s13rdar1580631514test1vyAA5ValueV_tF : $@convention(thin) (Value) -> ()
|
||||
// function_ref static Value.__derived_struct_equals(_:_:)
|
||||
// CHECK: [[EQUALS_REF:%.*]] = function_ref @$s13rdar1580631515ValueV23__derived_struct_equalsySbAC_ACtFZ
|
||||
// CHECK-NEXT: apply [[EQUALS_REF]](%0, {{.*}})
|
||||
// CHECK: [[EQUALS_REF:%.*]] = witness_method $Value, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||
// CHECK-NEXT: apply [[EQUALS_REF]]<Value>({{.*}})
|
||||
func test(v: Value) {
|
||||
_ = v == nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,6 @@ func testFoo() {
|
||||
// CHECK: [[@LINE+1]]:7 | instance-method/Swift | hash(into:) | s:14swift_ide_test9CustomFooV4hash4intoys6HasherVz_tF | {{.*}}Ref
|
||||
f.hash(into: &hasher)
|
||||
hasher.finalize()
|
||||
// CHECK: [[@LINE+1]]:11 | static-method/Swift | __derived_struct_equals(_:_:) | s:14swift_ide_test9CustomFooV23__derived_struct_equalsySbAC_ACtFZ | {{.*}}Ref
|
||||
// CHECK: [[@LINE+1]]:11 | static-method/infix-operator/Swift | ==(_:_:) | s:SQ2eeoiySbx_xtFZ | {{.*}}Ref
|
||||
_ = f == CustomFoo(a: 0, b: "b")
|
||||
}
|
||||
|
||||
@@ -48,15 +48,15 @@ public enum Alphabet : String {
|
||||
|
||||
// CHECK-LABEL: sil [ossa] @$s4main14check_alphabetySiAA8AlphabetOF : $@convention(thin) (Alphabet) -> Int {
|
||||
public func check_alphabet(_ state : Alphabet) -> Int {
|
||||
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
||||
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
||||
// FRAGILE: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||
// RESILIENT: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||
return state == .E ? 1 : 0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: sil [ossa] @$s4main9compareItySbAA8AlphabetO_ADtF : $@convention(thin) (Alphabet, Alphabet) -> Bool {
|
||||
public func compareIt(_ state : Alphabet, _ rhs: Alphabet) -> Bool {
|
||||
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
||||
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
||||
// FRAGILE: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||
// RESILIENT: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||
return state == rhs
|
||||
}
|
||||
|
||||
@@ -67,14 +67,14 @@ public enum AlphabetInt : Int {
|
||||
|
||||
// CHECK-LABEL: sil [ossa] @$s4main18check_alphabet_intySiAA11AlphabetIntOF : $@convention(thin) (AlphabetInt) -> Int {
|
||||
public func check_alphabet_int(_ state : AlphabetInt) -> Int {
|
||||
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
||||
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
||||
// FRAGILE: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||
// RESILIENT: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||
return state == .E ? 1 : 0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: sil [ossa] @$s4main9compareItySbAA11AlphabetIntO_ADtF : $@convention(thin) (AlphabetInt, AlphabetInt) -> Bool {
|
||||
public func compareIt(_ state : AlphabetInt, _ rhs: AlphabetInt) -> Bool {
|
||||
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
||||
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF
|
||||
// FRAGILE: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||
// RESILIENT: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
|
||||
return state == rhs
|
||||
}
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
|
||||
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=SILGEN
|
||||
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s --check-prefix=OPTIMIZED
|
||||
|
||||
// Operators are no longer devirtualized at AST level, it's done during SIL optimization.
|
||||
|
||||
infix operator +++
|
||||
|
||||
@@ -11,9 +14,13 @@ struct Branch : Twig {
|
||||
static func doIt(_: Branch, _: Branch) {}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
|
||||
// CHECK: function_ref @$s18protocol_operators6BranchV4doItyyAC_ACtFZ : $@convention(method) (Branch, Branch, @thin Branch.Type) -> ()
|
||||
// CHECK: return
|
||||
// SILGEN-LABEL: sil hidden [ossa] @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
|
||||
// SILGEN: witness_method $Branch, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> ()
|
||||
// SILGEN: return
|
||||
|
||||
// OPTIMIZED-LABEL: sil hidden @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
|
||||
// OPTIMIZED: function_ref @$s18protocol_operators6BranchV4doItyyAC_ACtFZ
|
||||
// OPTIMIZED: return
|
||||
func useBranch(_ b: Branch) {
|
||||
b +++ b
|
||||
}
|
||||
@@ -28,11 +35,17 @@ class Stuck : Stick, ExpressibleByIntegerLiteral {
|
||||
required init(integerLiteral: Int) {}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
|
||||
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||
// CHECK: witness_method $Stuck, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||
// CHECK: return
|
||||
// SILGEN-LABEL: sil hidden [ossa] @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
|
||||
// SILGEN: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||
// SILGEN: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||
// SILGEN: witness_method $Stuck, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||
// SILGEN: return
|
||||
|
||||
// OPTIMIZED-LABEL: sil hidden @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
|
||||
// OPTIMIZED: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||
// OPTIMIZED: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||
// OPTIMIZED: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
|
||||
// OPTIMIZED: return
|
||||
func useStick(_ a: Stuck, _ b: Stick) {
|
||||
_ = a +++ b
|
||||
_ = b +++ b
|
||||
@@ -49,10 +62,15 @@ class Rope : Twine<Int>, ExpressibleByIntegerLiteral {
|
||||
required init(integerLiteral: Int) {}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
|
||||
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||
// CHECK: witness_method $Rope, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||
// SILGEN-LABEL: sil hidden [ossa] @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
|
||||
// SILGEN: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||
// SILGEN: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||
// SILGEN: witness_method $Rope, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||
|
||||
// OPTIMIZED-LABEL: sil hidden @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
|
||||
// OPTIMIZED: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||
// OPTIMIZED: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||
// OPTIMIZED: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
|
||||
func useRope(_ r: Rope, _ s: Rope) {
|
||||
_ = r +++ s
|
||||
_ = s +++ s
|
||||
|
||||
Reference in New Issue
Block a user