Merge pull request #84800 from xedin/remove-csapply-operator-devirt

[CSApply] Don't attempt operator devirtualization
This commit is contained in:
Pavel Yaskevich
2025-10-18 23:09:23 +09:00
committed by GitHub
16 changed files with 989 additions and 367 deletions

View File

@@ -43,6 +43,7 @@
#include "swift/SILOptimizer/Differentiation/VJPCloner.h" #include "swift/SILOptimizer/Differentiation/VJPCloner.h"
#include "swift/SILOptimizer/PassManager/Passes.h" #include "swift/SILOptimizer/PassManager/Passes.h"
#include "swift/SILOptimizer/PassManager/Transforms.h" #include "swift/SILOptimizer/PassManager/Transforms.h"
#include "swift/SILOptimizer/Utils/Devirtualize.h"
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h" #include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "llvm/ADT/APSInt.h" #include "llvm/ADT/APSInt.h"
@@ -100,6 +101,24 @@ private:
SILBuilder &builder, SILLocation loc, SILBuilder &builder, SILLocation loc,
DifferentiationInvoker invoker); DifferentiationInvoker invoker);
/// Creates and canonicalizes private differentiability witness for
/// `originalFn`, differentiating with respect to
/// `desiredParameterIndices`. Returns `nullptr` on failure signifying that a
/// diagnostic has been emitted using `invoker`.
SILDifferentiabilityWitness *createPrivateDifferentiabilityWitness(
CanSILFunctionType originalFnTy, SILFunction *originalFn,
IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
SILValue original, DifferentiationInvoker invoker);
/// Resolves a differentiability witness for `originalFn`, either by looking
/// up a registered one with respect to a possible superset of
/// `desiredIndices`, or by creating a new private one. Returns `nullptr` on
/// failure, signifying that a diagnostic has been emitted using `invoker`.
SILDifferentiabilityWitness *getOrCreateMinimalDifferentiabilitywitness(
CanSILFunctionType originalFnTy, SILFunction *originalFn,
IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
SILValue original, DifferentiationInvoker invoker);
/// Emits a reference to a derivative function of `original`, differentiated /// Emits a reference to a derivative function of `original`, differentiated
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for /// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
/// the derivative function and the actual indices that the derivative /// the derivative function and the actual indices that the derivative
@@ -527,12 +546,142 @@ SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction(
return silFnIt->second; return silFnIt->second;
} }
SILDifferentiabilityWitness *
DifferentiationTransformer::createPrivateDifferentiabilityWitness(
    CanSILFunctionType originalFnTy, SILFunction *originalFn,
    IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
    SILValue original, DifferentiationInvoker invoker) {
  // Check non-differentiable cases before creating a new private
  // differentiability witness. Each failing check below emits a diagnostic
  // through `invoker` and returns nullptr.

  // If the function is intentionally marked as being opaque to
  // differentiation, then we should not create a task for it.
  if (originalFn->hasSemanticsAttr("autodiff.opaque")) {
    context.emitNondifferentiabilityError(
        original, invoker, diag::autodiff_opaque_function_not_differentiable);
    return nullptr;
  }

  // Check and diagnose non-differentiable arguments: every parameter selected
  // by `desiredParameterIndices` must have a differentiable SIL storage type.
  for (auto [paramIndex, param] :
       llvm::enumerate(originalFnTy->getParameters())) {
    if (!desiredParameterIndices->contains(paramIndex))
      continue;
    SILType paramType = param.getSILStorageInterfaceType();
    if (!paramType.isDifferentiable(context.getModule())) {
      context.emitNondifferentiabilityError(
          original, invoker, diag::autodiff_nondifferentiable_argument);
      return nullptr;
    }
  }

  // Check and diagnose non-differentiable results. Result indices are laid
  // out as: formal results first, then semantic result parameters, then
  // yields (see the index arithmetic below).
  unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
  unsigned firstYieldResultIndex =
      firstSemanticParamResultIdx +
      originalFnTy->getNumAutoDiffSemanticResultsParameters();
  for (auto resultIndex : desiredResultIndices->getIndices()) {
    SILType resultType;
    if (resultIndex >= firstYieldResultIndex) {
      auto yieldResultIndex = resultIndex - firstYieldResultIndex;
      auto yield = originalFnTy->getYields()[yieldResultIndex];
      // We can only differentiate indirect yields. This should be diagnosed
      // earlier in VJPCloner.
      assert(yield.isAutoDiffSemanticResult() && "unsupported result");
      resultType = yield.getSILStorageInterfaceType();
    } else if (resultIndex >= firstSemanticParamResultIdx) {
      auto semanticResultParamIdx = resultIndex - firstSemanticParamResultIdx;
      auto semanticResultParam = *std::next(
          originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
          semanticResultParamIdx);
      resultType = semanticResultParam.getSILStorageInterfaceType();
    } else {
      resultType =
          originalFnTy->getResults()[resultIndex].getSILStorageInterfaceType();
    }
    if (!resultType || !resultType.isDifferentiable(context.getModule())) {
      context.emitNondifferentiabilityError(
          original, invoker, diag::autodiff_nondifferentiable_result);
      return nullptr;
    }
  }

  // Check and diagnose external declarations: a function body is needed for
  // the canonicalization step below.
  if (originalFn->isExternalDeclaration()) {
    context.emitNondifferentiabilityError(
        original, invoker, diag::autodiff_external_nondifferentiable_function);
    return nullptr;
  }

  // Soundness check passed. Create a new differentiability witness.
  // For indirect differentiation, constrain the derivative generic signature
  // with the invoker's contextual derivative generic signature.
  GenericSignature contextualDerivativeGenSig = GenericSignature();
  if (invoker.getKind() ==
      DifferentiationInvoker::Kind::IndirectDifferentiation)
    contextualDerivativeGenSig = invoker.getIndirectDifferentiation()
                                     .second->getDerivativeGenericSignature();
  auto derivativeConstrainedGenSig =
      autodiff::getConstrainedDerivativeGenericSignature(
          originalFnTy, desiredParameterIndices, desiredResultIndices,
          contextualDerivativeGenSig, LookUpConformanceInModule());
  auto *witness = SILDifferentiabilityWitness::createDefinition(
      context.getModule(), SILLinkage::Private, originalFn,
      DifferentiabilityKind::Reverse, desiredParameterIndices,
      desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
      /*vjp*/ nullptr, /*isSerialized*/ false);
  // Canonicalization fills in the witness's derivative functions; a `true`
  // return means it failed (and has diagnosed), so report failure here too.
  if (canonicalizeDifferentiabilityWitness(witness, invoker, IsNotSerialized))
    return nullptr;
  return witness;
}
SILDifferentiabilityWitness *
DifferentiationTransformer::getOrCreateMinimalDifferentiabilitywitness(
    CanSILFunctionType originalFnTy, SILFunction *originalFn,
    IndexSubset *desiredParameterIndices, IndexSubset *desiredResultIndices,
    SILValue original, DifferentiationInvoker invoker) {
  // NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has
  // parameters corresponding to captured variables.
  // TODO: If possible, change `autodiff::getLoweredParameterIndices` to
  // take `CaptureInfo` into account.
  if (originalFnTy->getNumParameters() >
      desiredParameterIndices->getCapacity()) {
    desiredParameterIndices = desiredParameterIndices->extendingCapacity(
        context.getASTContext(), originalFnTy->getNumParameters());
  }

  // First, look up a differentiability witness with the exact configuration.
  auto *minimalWitness = getExactDifferentiabilityWitness(
      context.getModule(), originalFn, desiredParameterIndices,
      desiredResultIndices);

  // Otherwise, look up a differentiability witness with a minimal superset
  // configuration.
  if (!minimalWitness)
    minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness(
        context.getModule(), originalFn, DifferentiabilityKind::Reverse,
        desiredParameterIndices, desiredResultIndices);

  // Finally, try to create a private differentiability witness. That helper
  // diagnoses every non-differentiability case, so a nullptr result from this
  // function means an error has already been emitted.
  if (!minimalWitness)
    minimalWitness = createPrivateDifferentiabilityWitness(
        originalFnTy, originalFn, desiredParameterIndices, desiredResultIndices,
        original, invoker);
  return minimalWitness;
}
std::optional<std::pair<SILValue, AutoDiffConfig>> std::optional<std::pair<SILValue, AutoDiffConfig>>
DifferentiationTransformer::emitDerivativeFunctionReference( DifferentiationTransformer::emitDerivativeFunctionReference(
SILBuilder &builder, const AutoDiffConfig &desiredConfig, SILBuilder &builder, const AutoDiffConfig &desiredConfig,
AutoDiffDerivativeFunctionKind kind, SILValue original, AutoDiffDerivativeFunctionKind kind, SILValue original,
DifferentiationInvoker invoker, DifferentiationInvoker invoker,
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) { SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
SILFunction *parentFn = original->getFunction();
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind // If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
// matches the given kind and desired differentiation parameter indices, // matches the given kind and desired differentiation parameter indices,
// simply extract the derivative function of its function operand, retain the // simply extract the derivative function of its function operand, retain the
@@ -576,111 +725,16 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
auto loc = originalFRI->getLoc(); auto loc = originalFRI->getLoc();
auto *originalFn = originalFRI->getReferencedFunction(); auto *originalFn = originalFRI->getReferencedFunction();
auto originalFnTy = originalFn->getLoweredFunctionType(); auto originalFnTy = originalFn->getLoweredFunctionType();
auto *desiredParameterIndices = desiredConfig.parameterIndices;
auto *desiredResultIndices = desiredConfig.resultIndices; auto *minimalWitness = getOrCreateMinimalDifferentiabilitywitness(
// NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has originalFnTy, originalFn, desiredConfig.parameterIndices,
// parameters corresponding to captured variables. desiredConfig.resultIndices, original, invoker);
// TODO: If possible, change `autodiff::getLoweredParameterIndices` to
// take `CaptureInfo` into account. // All non-differentiability cases should be diagnosted
if (originalFnTy->getNumParameters() >
desiredParameterIndices->getCapacity()) {
desiredParameterIndices = desiredParameterIndices->extendingCapacity(
context.getASTContext(), originalFnTy->getNumParameters());
}
// Look up a differentiability witness with the exact configuration.
auto *minimalWitness = getExactDifferentiabilityWitness(
context.getModule(), originalFn, desiredParameterIndices,
desiredResultIndices);
// Otherwise, look up a differentiability witness with a minimal superset
// configuration.
if (!minimalWitness) if (!minimalWitness)
minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness( return std::nullopt;
context.getModule(), originalFn, DifferentiabilityKind::Reverse,
desiredParameterIndices, desiredResultIndices); if (parentFn->isSerialized()) {
// If no minimal witness exists, check non-differentiable cases before
// creating a new private differentiability witness.
if (!minimalWitness) {
// If the function is intentionally marked as being opaque to
// differentiation, then we should not create a task for it.
if (originalFn->hasSemanticsAttr("autodiff.opaque")) {
context.emitNondifferentiabilityError(
original, invoker,
diag::autodiff_opaque_function_not_differentiable);
return std::nullopt;
}
// Check and diagnose non-differentiable arguments.
auto originalFnTy = originalFn->getLoweredFunctionType();
for (unsigned paramIndex : range(originalFnTy->getNumParameters())) {
if (desiredConfig.isWrtParameter(paramIndex) &&
!originalFnTy->getParameters()[paramIndex]
.getSILStorageInterfaceType()
.isDifferentiable(context.getModule())) {
auto diag = context.emitNondifferentiabilityError(
original, invoker, diag::autodiff_nondifferentiable_argument);
return std::nullopt;
}
}
// Check and diagnose non-differentiable results.
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
unsigned firstYieldResultIndex = originalFnTy->getNumResults() +
originalFnTy->getNumAutoDiffSemanticResultsParameters();
for (auto resultIndex : desiredResultIndices->getIndices()) {
SILType resultType;
if (resultIndex >= firstYieldResultIndex) {
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
auto yield = originalFnTy->getYields()[yieldResultIndex];
// We can only differentiate indirect yields. This should be diagnosed
// earlier in VJPCloner.
assert(yield.isAutoDiffSemanticResult() && "unsupported result");
resultType = yield.getSILStorageInterfaceType();
} else if (resultIndex >= firstSemanticParamResultIdx) {
auto semanticResultParamIdx = resultIndex - firstSemanticParamResultIdx;
auto semanticResultParam =
*std::next(originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
semanticResultParamIdx);
resultType = semanticResultParam.getSILStorageInterfaceType();
} else {
resultType = originalFnTy->getResults()[resultIndex]
.getSILStorageInterfaceType();
}
if (!resultType || !resultType.isDifferentiable(context.getModule())) {
context.emitNondifferentiabilityError(
original, invoker, diag::autodiff_nondifferentiable_result);
return std::nullopt;
}
}
// Check and diagnose external declarations.
if (originalFn->isExternalDeclaration()) {
context.emitNondifferentiabilityError(
original, invoker,
diag::autodiff_external_nondifferentiable_function);
return std::nullopt;
}
// Soundness check passed. Create a new differentiability witness and
// canonicalize it.
GenericSignature contextualDerivativeGenSig = GenericSignature();
if (invoker.getKind() ==
DifferentiationInvoker::Kind::IndirectDifferentiation)
contextualDerivativeGenSig =
invoker.getIndirectDifferentiation()
.second->getDerivativeGenericSignature();
auto derivativeConstrainedGenSig =
autodiff::getConstrainedDerivativeGenericSignature(
originalFn->getLoweredFunctionType(),
desiredParameterIndices, desiredResultIndices,
contextualDerivativeGenSig,
LookUpConformanceInModule());
minimalWitness = SILDifferentiabilityWitness::createDefinition(
context.getModule(), SILLinkage::Private, originalFn,
DifferentiabilityKind::Reverse, desiredParameterIndices,
desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
/*vjp*/ nullptr, /*isSerialized*/ false);
if (canonicalizeDifferentiabilityWitness(minimalWitness, invoker,
IsNotSerialized))
return std::nullopt;
}
assert(minimalWitness);
if (original->getFunction()->isSerialized()) {
bool isWitnessPublic; bool isWitnessPublic;
if (SILFunction *unwrappedFn = if (SILFunction *unwrappedFn =
getUnwrappedCurryThunkFunction(originalFn)) { getUnwrappedCurryThunkFunction(originalFn)) {
@@ -691,15 +745,12 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
// private linkage. // private linkage.
isWitnessPublic = hasPublicVisibility(unwrappedFn->getLinkage()); isWitnessPublic = hasPublicVisibility(unwrappedFn->getLinkage());
} else if (originalFn->getDeclRef().getAbstractClosureExpr() && } else if (originalFn->getDeclRef().getAbstractClosureExpr() &&
originalFRI->getFunction() parentFn->getDeclRef().isDefaultArgGenerator()) {
->getDeclRef()
.isDefaultArgGenerator()) {
// If we reference a closure from inside default argument generator, // If we reference a closure from inside default argument generator,
// check against generator's visibility. If the function having this // check against generator's visibility. If the function having this
// default argument has public visibility, it's OK to have a closure // default argument has public visibility, it's OK to have a closure
// (which always has private visibility) as its default value. // (which always has private visibility) as its default value.
isWitnessPublic = isWitnessPublic = hasPublicVisibility(parentFn->getLinkage());
hasPublicVisibility(originalFRI->getFunction()->getLinkage());
} else { } else {
isWitnessPublic = hasPublicVisibility(minimalWitness->getLinkage()); isWitnessPublic = hasPublicVisibility(minimalWitness->getLinkage());
} }
@@ -709,16 +760,16 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
// FIXME: This is not a very robust way of determining if the function // FIXME: This is not a very robust way of determining if the function
// is a default argument. Also, we have not exhaustively listed all the // is a default argument. Also, we have not exhaustively listed all the
// kinds of fragility. // kinds of fragility.
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI) if (parentFn->getLinkage() == SILLinkage::PublicNonABI)
fragileKind = DefaultArgument; fragileKind = DefaultArgument;
context.emitNondifferentiabilityError( context.emitNondifferentiabilityError(
original, invoker, diag::autodiff_private_derivative_from_fragile, original, invoker, diag::autodiff_private_derivative_from_fragile,
fragileKind, fragileKind,
isa_and_nonnull<AbstractClosureExpr>( isa_and_nonnull<AbstractClosureExpr>(loc.getAsASTNode<Expr>()));
originalFRI->getLoc().getAsASTNode<Expr>()));
return std::nullopt; return std::nullopt;
} }
} }
// TODO(TF-482): Move generic requirement checking logic to // TODO(TF-482): Move generic requirement checking logic to
// `getExactDifferentiabilityWitness` and // `getExactDifferentiabilityWitness` and
// `getOrCreateMinimalASTDifferentiabilityWitness`. // `getOrCreateMinimalASTDifferentiabilityWitness`.
@@ -726,18 +777,17 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
// By default, use the forwarding substitution map of the original function. // By default, use the forwarding substitution map of the original function.
// If the original callee is a `partial_apply` or `apply` instruction, use // If the original callee is a `partial_apply` or `apply` instruction, use
// its substitution map instead. // its substitution map instead.
auto substMap = original->getFunction()->getForwardingSubstitutionMap(); auto substMap = parentFn->getForwardingSubstitutionMap();
if (auto *pai = if (auto *pai = peerThroughFunctionConversions<PartialApplyInst>(original))
peerThroughFunctionConversions<PartialApplyInst>(original)) {
substMap = pai->getSubstitutionMap(); substMap = pai->getSubstitutionMap();
} else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original)) { else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original))
substMap = ai->getSubstitutionMap(); substMap = ai->getSubstitutionMap();
}
if (diagnoseUnsatisfiedRequirements( if (diagnoseUnsatisfiedRequirements(
context, original->getType().castTo<SILFunctionType>(), context, original->getType().castTo<SILFunctionType>(),
minimalWitness->getDerivativeGenericSignature(), substMap, invoker, minimalWitness->getDerivativeGenericSignature(), substMap, invoker,
original.getLoc().getSourceLoc())) original.getLoc().getSourceLoc()))
return std::nullopt; return std::nullopt;
DifferentiabilityWitnessFunctionKind witnessKind; DifferentiabilityWitnessFunctionKind witnessKind;
switch (kind) { switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP: case AutoDiffDerivativeFunctionKind::JVP:
@@ -764,6 +814,79 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
if (auto *witnessMethod = if (auto *witnessMethod =
peerThroughFunctionConversions<WitnessMethodInst>(original)) { peerThroughFunctionConversions<WitnessMethodInst>(original)) {
auto loc = witnessMethod->getLoc(); auto loc = witnessMethod->getLoc();
// See if we can derive derivatives for a method statically
auto [method, table] = lookUpFunctionInWitnessTable(
witnessMethod, SILModule::LinkingMode::LinkNormal);
if (method) {
auto *originalFn = method;
auto originalFnTy = method->getLoweredFunctionTypeInContext(
TypeExpansionContext(*original->getFunction()));
auto *minimalWitness = getOrCreateMinimalDifferentiabilitywitness(
originalFnTy, originalFn, desiredConfig.parameterIndices,
desiredConfig.resultIndices, original, invoker);
// All non-differentiability cases should have been diagnosed above.
if (!minimalWitness)
return std::nullopt;
if (parentFn->isSerialized() &&
!hasPublicVisibility(minimalWitness->getLinkage())) {
enum { Inlinable = 0, DefaultArgument = 1 };
unsigned fragileKind = Inlinable;
// FIXME: This is not a very robust way of determining if the function
// is a default argument. Also, we have not exhaustively listed all the
// kinds of fragility.
if (parentFn->getLinkage() == SILLinkage::PublicNonABI)
fragileKind = DefaultArgument;
context.emitNondifferentiabilityError(
original, invoker, diag::autodiff_private_derivative_from_fragile,
fragileKind,
isa_and_nonnull<AbstractClosureExpr>(loc.getAsASTNode<Expr>()));
return std::nullopt;
}
auto substMap = parentFn->getForwardingSubstitutionMap();
if (auto *pai =
peerThroughFunctionConversions<PartialApplyInst>(original))
substMap =
getWitnessMethodSubstitutions(context.getModule(), pai, originalFn,
witnessMethod->getConformance());
else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original))
substMap =
getWitnessMethodSubstitutions(context.getModule(), ai, originalFn,
witnessMethod->getConformance());
if (diagnoseUnsatisfiedRequirements(
context, original->getType().castTo<SILFunctionType>(),
minimalWitness->getDerivativeGenericSignature(), substMap,
invoker, original.getLoc().getSourceLoc()))
return std::nullopt;
DifferentiabilityWitnessFunctionKind witnessKind;
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP:
witnessKind = DifferentiabilityWitnessFunctionKind::JVP;
break;
case AutoDiffDerivativeFunctionKind::VJP:
witnessKind = DifferentiabilityWitnessFunctionKind::VJP;
break;
}
auto *derivativeFnRef = builder.createDifferentiabilityWitnessFunction(
loc, witnessKind, minimalWitness);
auto convertedRef = reapplyFunctionConversion(
context, derivativeFnRef, witnessMethod, original, builder, loc,
newBuffersToDealloc, desiredConfig.parameterIndices,
desiredConfig.resultIndices,
derivativeFnRef->getType()
.getASTType()
->castTo<SILFunctionType>()
->getSubstGenericSignature());
return std::make_pair(convertedRef, minimalWitness->getConfig());
}
auto requirementDeclRef = witnessMethod->getMember(); auto requirementDeclRef = witnessMethod->getMember();
auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl(); auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl();
// If requirement declaration does not have any derivative function // If requirement declaration does not have any derivative function

View File

@@ -588,66 +588,6 @@ namespace {
} }
} }
// Returns None if the AST does not contain enough information to recover
// substitutions; this is different from an Optional(SubstitutionMap()),
// indicating a valid call to a non-generic operator.
std::optional<SubstitutionMap> getOperatorSubstitutions(ValueDecl *witness,
Type refType) {
// We have to recover substitutions in this hacky way because
// the AST does not retain enough information to devirtualize
// calls like this.
auto witnessType = witness->getInterfaceType();
// Compute the substitutions.
auto *gft = witnessType->getAs<GenericFunctionType>();
if (gft == nullptr) {
if (refType->isEqual(witnessType))
return SubstitutionMap();
return std::nullopt;
}
auto sig = gft->getGenericSignature();
auto *env = sig.getGenericEnvironment();
witnessType = FunctionType::get(gft->getParams(),
gft->getResult(),
gft->getExtInfo());
witnessType = env->mapTypeIntoContext(witnessType);
TypeSubstitutionMap subs;
auto substType = witnessType->substituteBindingsTo(
refType,
[&](ArchetypeType *origType, CanType substType) -> CanType {
if (auto gpType = dyn_cast<GenericTypeParamType>(
origType->getInterfaceType()->getCanonicalType()))
subs[gpType] = substType;
return substType;
});
// If substitution failed, it means that the protocol requirement type
// and the witness type did not match up. The only time that this
// should happen is when the witness is defined in a base class and
// the actual call uses a derived class. For example,
//
// protocol P { func +(lhs: Self, rhs: Self) }
// class Base : P { func +(lhs: Base, rhs: Base) {} }
// class Derived : Base {}
//
// If we enter this code path with two operands of type Derived,
// we know we're calling the protocol requirement P.+, with a
// substituted type of (Derived, Derived) -> (). But the type of
// the witness is (Base, Base) -> (). Just bail out and make a
// witness method call in this rare case; SIL mandatory optimizations
// will likely devirtualize it anyway.
if (!substType)
return std::nullopt;
return SubstitutionMap::get(sig,
QueryTypeSubstitutionMap{subs},
LookUpConformanceInModule());
}
/// Determine whether the given reference is to a method on /// Determine whether the given reference is to a method on
/// a remote distributed actor in the given context. /// a remote distributed actor in the given context.
bool isDistributedThunk(ConcreteDeclRef ref, Expr *context); bool isDistributedThunk(ConcreteDeclRef ref, Expr *context);
@@ -674,65 +614,6 @@ namespace {
auto baseTy = getBaseType(adjustedFullType->castTo<FunctionType>()); auto baseTy = getBaseType(adjustedFullType->castTo<FunctionType>());
// Handle operator requirements found in protocols.
if (auto proto = dyn_cast<ProtocolDecl>(decl->getDeclContext())) {
bool isCurried = shouldBuildCurryThunk(choice, /*baseIsInstance=*/false);
// If we have a concrete conformance, build a call to the witness.
//
// FIXME: This is awful. We should be able to handle this as a call to
// the protocol requirement with Self == the concrete type, and SILGen
// (or later) can devirtualize as appropriate.
auto conformance = checkConformance(baseTy, proto);
if (conformance.isConcrete()) {
if (auto witness = conformance.getConcrete()->getWitnessDecl(decl)) {
bool isMemberOperator = witness->getDeclContext()->isTypeContext();
if (!isMemberOperator || !isCurried) {
// The fullType was computed by substituting the protocol
// requirement so it always has a (Self) -> ... curried
// application. Strip it off if the witness was a top-level
// function.
Type refType;
if (isMemberOperator)
refType = adjustedFullType;
else
refType = adjustedFullType->castTo<AnyFunctionType>()->getResult();
// Build the AST for the call to the witness.
auto subMap = getOperatorSubstitutions(witness, refType);
if (subMap) {
ConcreteDeclRef witnessRef(witness, *subMap);
auto declRefExpr = new (ctx) DeclRefExpr(witnessRef, loc,
/*Implicit=*/false);
declRefExpr->setFunctionRefInfo(choice.getFunctionRefInfo());
cs.setType(declRefExpr, refType);
Expr *refExpr;
if (isMemberOperator) {
// If the operator is a type member, add the implicit
// (Self) -> ... call.
Expr *base =
TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy,
ctx);
cs.setType(base, MetatypeType::get(baseTy));
refExpr =
DotSyntaxCallExpr::create(ctx, declRefExpr, SourceLoc(),
Argument::unlabeled(base));
auto refType = adjustedFullType->castTo<FunctionType>()->getResult();
cs.setType(refExpr, refType);
} else {
refExpr = declRefExpr;
}
return forceUnwrapIfExpected(refExpr, locator);
}
}
}
}
}
// Build a reference to the member. // Build a reference to the member.
Expr *base = Expr *base =
TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy, ctx); TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy, ctx);

View File

@@ -2,6 +2,9 @@
// REQUIRES: executable_test // REQUIRES: executable_test
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext. // Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.
/* Temporarily disabled until https://github.com/swiftlang/swift/issues/84840 is fixed:
   we cannot use `SIMD` here. */
// XFAIL: *
import _Differentiation import _Differentiation
import StdlibUnittest import StdlibUnittest

View File

@@ -2,6 +2,9 @@
// NOTE: Verify whether forward-mode differentiation crashes. It currently does. // NOTE: Verify whether forward-mode differentiation crashes. It currently does.
// RUN: not --crash %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil %s // RUN: not --crash %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil %s
// REQUIRES: executable_test // REQUIRES: executable_test
/* Temporarily disabled until https://github.com/swiftlang/swift/issues/84840 is fixed:
   we cannot use `Tracked<T>` here. */
// XFAIL: *
import StdlibUnittest import StdlibUnittest
import DifferentiationUnittest import DifferentiationUnittest

View File

@@ -9,6 +9,144 @@ import DifferentiationUnittest
var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty") var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty")
// Tangent space for the `Space` type below: a simple 2-D vector of `Float`s
// that serves as its own tangent vector.
struct TangentSpace : AdditiveArithmetic {
  let x, y: Float
}

extension TangentSpace : Differentiable {
  typealias TangentVector = TangentSpace
}
// `Space` pairs a stored differentiable property (`y`) with a computed
// property (`x`) whose reverse-mode derivative is registered manually via
// `@derivative(of: x)`.
struct Space {
  /// `x` is a computed property with a custom vjp.
  var x: Float {
    @differentiable(reverse)
    get { storedX }
    set { storedX = newValue }
  }

  // Custom VJP for `x`: the pullback routes the incoming gradient entirely
  // into the `x` component of the tangent space.
  @derivative(of: x)
  func vjpX() -> (value: Float, pullback: (Float) -> TangentSpace) {
    return (x, { v in TangentSpace(x: v, y: 0) } )
  }

  // Backing storage for the computed property `x`.
  private var storedX: Float

  @differentiable(reverse)
  var y: Float

  init(x: Float, y: Float) {
    self.storedX = x
    self.y = y
  }
}
// Manual `Differentiable` conformance: `TangentSpace` is the tangent vector,
// and `move(by:)` applies the offset componentwise.
extension Space : Differentiable {
  typealias TangentVector = TangentSpace
  mutating func move(by offset: TangentSpace) {
    x.move(by: offset.x)
    y.move(by: offset.y)
  }
}
// d/dx (2 * x) = 2; the custom VJP sends the gradient into `x` only.
E2EDifferentiablePropertyTests.test("computed property") {
  let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
    return 2 * point.x
  }
  let expectedGrad = TangentSpace(x: 2, y: 0)
  expectEqual(expectedGrad, actualGrad)
}
// d/dy (3 * y) = 3; the stored property's derivative is synthesized.
E2EDifferentiablePropertyTests.test("stored property") {
  let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
    return 3 * point.y
  }
  let expectedGrad = TangentSpace(x: 0, y: 3)
  expectEqual(expectedGrad, actualGrad)
}
// Wrapper with a generic differentiable stored property.
struct GenericMemberWrapper<T : Differentiable> : Differentiable {
  // Stored property.
  @differentiable(reverse)
  var x: T

  // NOTE(review): this helper is not registered with `@derivative(of: x)`;
  // presumably the stored property's derivative is synthesized and this
  // function is intentionally unused — confirm.
  func vjpX() -> (T, (T.TangentVector) -> GenericMemberWrapper.TangentVector) {
    return (x, { TangentVector(x: $0) })
  }
}
// d/dx (2 * x) = 2 through a generic stored property (T == Float).
E2EDifferentiablePropertyTests.test("generic stored property") {
  let actualGrad = gradient(at: GenericMemberWrapper<Float>(x: 1)) { point in
    return 2 * point.x
  }
  let expectedGrad = GenericMemberWrapper<Float>.TangentVector(x: 2)
  expectEqual(expectedGrad, actualGrad)
}
// Product space whose tangent vector is the type itself
// (`Self == TangentVector`).
struct ProductSpaceSelfTangent : AdditiveArithmetic {
  let x, y: Float
}

extension ProductSpaceSelfTangent : Differentiable {
  typealias TangentVector = ProductSpaceSelfTangent
}

// d/dy (5 * y) = 5; the gradient lands in the `y` component.
E2EDifferentiablePropertyTests.test("fieldwise product space, self tangent") {
  let actualGrad = gradient(at: ProductSpaceSelfTangent(x: 0, y: 0)) { (point: ProductSpaceSelfTangent) -> Float in
    return 5 * point.y
  }
  let expectedGrad = ProductSpaceSelfTangent(x: 0, y: 5)
  expectEqual(expectedGrad, actualGrad)
}
// Product space whose tangent vector is a distinct type, requiring a manual
// `move(by:)` implementation.
struct ProductSpaceOtherTangentTangentSpace : AdditiveArithmetic {
  let x, y: Float
}

extension ProductSpaceOtherTangentTangentSpace : Differentiable {
  typealias TangentVector = ProductSpaceOtherTangentTangentSpace
}

struct ProductSpaceOtherTangent {
  var x, y: Float
}

extension ProductSpaceOtherTangent : Differentiable {
  typealias TangentVector = ProductSpaceOtherTangentTangentSpace
  mutating func move(by offset: ProductSpaceOtherTangentTangentSpace) {
    x.move(by: offset.x)
    y.move(by: offset.y)
  }
}

// d/dy (7 * y) = 7, flowing through the distinct tangent type.
E2EDifferentiablePropertyTests.test("fieldwise product space, other tangent") {
  let actualGrad = gradient(
    at: ProductSpaceOtherTangent(x: 0, y: 0)
  ) { (point: ProductSpaceOtherTangent) -> Float in
    return 7 * point.y
  }
  let expectedGrad = ProductSpaceOtherTangentTangentSpace(x: 0, y: 7)
  expectEqual(expectedGrad, actualGrad)
}
// Regression test for TF-544: differentiating through a computed property
// backed by a stored property, with synthesized derivatives.
// Renamed from "computed property" to avoid a duplicate test name with the
// earlier `Space` test registered in this same suite.
E2EDifferentiablePropertyTests.test("computed property (TF-544)") {
  struct TF_544 : Differentiable {
    var value: Float
    @differentiable(reverse)
    var computed: Float {
      get { value }
      set { value = newValue }
    }
  }
  // d/dx (x * x) at x = 2.4 is 4.8.
  let actualGrad = gradient(at: TF_544(value: 2.4)) { x in
    return x.computed * x.computed
  }
  let expectedGrad = TF_544.TangentVector(value: 4.8)
  expectEqual(expectedGrad, actualGrad)
}
/* Temporarily disabled until https://github.com/swiftlang/swift/issues/84840 is fixed:
we cannot use `Tracked<T>` here.
struct TangentSpace : AdditiveArithmetic { struct TangentSpace : AdditiveArithmetic {
let x, y: Tracked<Float> let x, y: Tracked<Float>
} }
@@ -144,5 +282,6 @@ E2EDifferentiablePropertyTests.testWithLeakChecking("computed property") {
let expectedGrad = TF_544.TangentVector(value: 4.8) let expectedGrad = TF_544.TangentVector(value: 4.8)
expectEqual(expectedGrad, actualGrad) expectEqual(expectedGrad, actualGrad)
} }
*/
runAllTests() runAllTests()

View File

@@ -8,19 +8,39 @@ var ExistentialTests = TestSuite("Existential")
protocol A { protocol A {
@differentiable(reverse, wrt: x) @differentiable(reverse, wrt: x)
func a(_ x: Tracked<Float>) -> Tracked<Float> func a(_ x: Float) -> Float
} }
func b(g: A) -> Tracked<Float> { func b(g: A) -> Float {
return gradient(at: 3) { x in g.a(x) } return gradient(at: 3) { x in g.a(x) }
} }
struct B : A { struct B : A {
@differentiable(reverse, wrt: x) @differentiable(reverse, wrt: x)
func a(_ x: Tracked<Float>) -> Tracked<Float> { return x * 5 } func a(_ x: Float) -> Float { return x * 5 }
} }
ExistentialTests.testWithLeakChecking("Existential method VJP") { ExistentialTests.test("Existential method VJP-Tracked") {
expectEqual(5.0, b(g: B())) expectEqual(5.0, b(g: B()))
} }
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
We cannot use `Tracked<T>` :(
protocol ATracked {
@differentiable(reverse, wrt: x)
func a(_ x: Tracked<Float>) -> Tracked<Float>
}
func b(g: ATracked) -> Tracked<Float> {
return gradient(at: 3) { x in g.a(x) }
}
struct BTracked : ATracked {
@differentiable(reverse, wrt: x)
func a(_ x: Tracked<Float>) -> Tracked<Float> { return x * 5 }
}
ExistentialTests.testWithLeakChecking("Existential method VJP-Tracked") {
expectEqual(5.0, b(g: BTracked()))
}
*/
runAllTests() runAllTests()

View File

@@ -1,5 +1,8 @@
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation) // RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
// REQUIRES: executable_test // REQUIRES: executable_test
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
We cannot use `SIMD` :( */
// XFAIL: *
import StdlibUnittest import StdlibUnittest
import DifferentiationUnittest import DifferentiationUnittest

View File

@@ -1,5 +1,8 @@
// RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation) // RUN: %target-run-simple-swift(-Xfrontend -enable-experimental-forward-mode-differentiation)
// REQUIRES: executable_test // REQUIRES: executable_test
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
We cannot use `Tracked<T>` :( */
// XFAIL: *
import StdlibUnittest import StdlibUnittest
import DifferentiationUnittest import DifferentiationUnittest

View File

@@ -9,50 +9,50 @@ var MethodTests = TestSuite("Method")
// ==== Tests with generated adjoint ==== // ==== Tests with generated adjoint ====
struct Parameter : Equatable { struct Parameter : Equatable {
private let storedX: Tracked<Float> private let storedX: Float
@differentiable(reverse, wrt: (self)) @differentiable(reverse, wrt: (self))
var x: Tracked<Float> { var x: Float {
return storedX return storedX
} }
init(x: Tracked<Float>) { init(x: Float) {
storedX = x storedX = x
} }
@derivative(of: x) @derivative(of: x)
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Parameter) { func vjpX() -> (value: Float, pullback: (Float) -> Parameter) {
return (x, { dx in Parameter(x: dx) } ) return (x, { dx in Parameter(x: dx) } )
} }
@derivative(of: x) @derivative(of: x)
func jvpX() -> (value: Tracked<Float>, differential: (Parameter) -> Tracked<Float>) { func jvpX() -> (value: Float, differential: (Parameter) -> Float) {
return (x, { $0.x }) return (x, { $0.x })
} }
} }
extension Parameter { extension Parameter {
func squared() -> Tracked<Float> { func squared() -> Float {
return x * x return x * x
} }
static func squared(p: Parameter) -> Tracked<Float> { static func squared(p: Parameter) -> Float {
return p.x * p.x return p.x * p.x
} }
func multiplied(with other: Tracked<Float>) -> Tracked<Float> { func multiplied(with other: Float) -> Float {
return x * other return x * other
} }
static func * (_ a: Parameter, _ b: Parameter) -> Tracked<Float> { static func * (_ a: Parameter, _ b: Parameter) -> Float {
return a.x * b.x return a.x * b.x
} }
} }
extension Parameter : Differentiable, AdditiveArithmetic { extension Parameter : Differentiable, AdditiveArithmetic {
typealias TangentVector = Parameter typealias TangentVector = Parameter
typealias Scalar = Tracked<Float> typealias Scalar = Float
typealias Shape = () typealias Shape = ()
init(repeating repeatedValue: Tracked<Float>, shape: ()) { init(repeating repeatedValue: Float, shape: ()) {
self.init(x: repeatedValue) self.init(x: repeatedValue)
} }
static func + (lhs: Parameter, rhs: Parameter) -> Parameter { static func + (lhs: Parameter, rhs: Parameter) -> Parameter {
@@ -67,37 +67,185 @@ extension Parameter : Differentiable, AdditiveArithmetic {
static var zero: Parameter { return Parameter(x: 0) } static var zero: Parameter { return Parameter(x: 0) }
} }
MethodTests.testWithLeakChecking( MethodTests.test(
"instance method with generated adjoint, called from differentiated func" "instance method with generated adjoint, called from differentiated func"
) { ) {
func f(_ p: Parameter) -> Tracked<Float> { func f(_ p: Parameter) -> Float {
return 100 * p.squared() return 100 * p.squared()
} }
expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f)) expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f)) expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
} }
MethodTests.test(
"instance method with generated adjoint, differentiated directly"
) {
// This is our current syntax for taking gradients of instance methods
// directly. If/when we develop nicer syntax for this, change this test.
func g(p: Parameter) -> Float { p.squared() }
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: g))
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: g))
}
MethodTests.test("instance method with generated adjoint, wrt only self") {
func f(_ p: Parameter) -> Float {
return 100 * p.multiplied(with: 200)
}
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
}
MethodTests.test("instance method with generated adjoint, wrt only non-self") {
func f(_ other: Float) -> Float {
return 100 * Parameter(x: 200).multiplied(with: other)
}
expectEqual(100 * 200, gradient(at: 1, of: f))
expectEqual(100 * 200, gradient(at: 2, of: f))
}
MethodTests.test(
"instance method with generated adjoint, wrt self and non-self"
) {
expectEqual(
(Parameter(x: 100), 200), gradient(at: Parameter(x: 200), 100) { $0.multiplied(with: $1) })
expectEqual(
(Parameter(x: 200), 100), gradient(at: Parameter(x: 100), 200) { $0.multiplied(with: $1) })
}
MethodTests.test(
"static method with generated adjoint, called from differentiated func"
) {
func f(_ p: Parameter) -> Float {
return 100 * Parameter.squared(p: p)
}
expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
}
MethodTests.test(
"static method with generated adjoint, differentiated directly"
) {
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: Parameter.squared))
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: Parameter.squared))
}
MethodTests.test("static method with generated adjoint, wrt only first param") {
func f(_ p: Parameter) -> Float {
return 100 * (p * Parameter(x: 200))
}
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
}
MethodTests.test("static method with generated adjoint, wrt only second param") {
func f(_ p: Parameter) -> Float {
return 100 * (Parameter(x: 200) * p)
}
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
}
MethodTests.test("static method with generated adjoint, wrt all params") {
func g(a: Parameter, b: Parameter) -> Float { a * b }
expectEqual((Parameter(x: 100), Parameter(x: 200)),
gradient(at: Parameter(x: 200), Parameter(x: 100), of: g))
expectEqual((Parameter(x: 200), Parameter(x: 100)),
gradient(at: Parameter(x: 100), Parameter(x: 200), of: g))
}
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
We cannot use `Tracked<T>` :(
struct ParameterTracked : Equatable {
private let storedX: Tracked<Float>
@differentiable(reverse, wrt: (self))
var x: Tracked<Float> {
return storedX
}
init(x: Tracked<Float>) {
storedX = x
}
@derivative(of: x)
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> ParameterTracked) {
return (x, { dx in ParameterTracked(x: dx) } )
}
@derivative(of: x)
func jvpX() -> (value: Tracked<Float>, differential: (ParameterTracked) -> Tracked<Float>) {
return (x, { $0.x })
}
}
extension ParameterTracked {
func squared() -> Tracked<Float> {
return x * x
}
static func squared(p: ParameterTracked) -> Tracked<Float> {
return p.x * p.x
}
func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
return x * other
}
static func * (_ a: ParameterTracked, _ b: ParameterTracked) -> Tracked<Float> {
return a.x * b.x
}
}
extension ParameterTracked : Differentiable, AdditiveArithmetic {
typealias TangentVector = ParameterTracked
typealias Scalar = Tracked<Float>
typealias Shape = ()
init(repeating repeatedValue: Tracked<Float>, shape: ()) {
self.init(x: repeatedValue)
}
static func + (lhs: ParameterTracked, rhs: ParameterTracked) -> ParameterTracked {
return ParameterTracked(x: lhs.x + rhs.x)
}
static func - (lhs: ParameterTracked, rhs: ParameterTracked) -> ParameterTracked {
return ParameterTracked(x: lhs.x - rhs.x)
}
static func * (lhs: Scalar, rhs: ParameterTracked) -> ParameterTracked {
return ParameterTracked(x: lhs * rhs.x)
}
static var zero: ParameterTracked { return ParameterTracked(x: 0) }
}
MethodTests.testWithLeakChecking(
"instance method with generated adjoint, called from differentiated func"
) {
func f(_ p: ParameterTracked) -> Tracked<Float> {
return 100 * p.squared()
}
expectEqual(ParameterTracked(x: 4 * 100), gradient(at: ParameterTracked(x: 2), of: f))
expectEqual(ParameterTracked(x: 40 * 100), gradient(at: ParameterTracked(x: 20), of: f))
}
MethodTests.testWithLeakChecking( MethodTests.testWithLeakChecking(
"instance method with generated adjoint, differentiated directly" "instance method with generated adjoint, differentiated directly"
) { ) {
// This is our current syntax for taking gradients of instance methods // This is our current syntax for taking gradients of instance methods
// directly. If/when we develop nicer syntax for this, change this test. // directly. If/when we develop nicer syntax for this, change this test.
func g(p: Parameter) -> Tracked<Float> { p.squared() } func g(p: ParameterTracked) -> Tracked<Float> { p.squared() }
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: g)) expectEqual(ParameterTracked(x: 4), gradient(at: ParameterTracked(x: 2), of: g))
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: g)) expectEqual(ParameterTracked(x: 40), gradient(at: ParameterTracked(x: 20), of: g))
} }
MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only self") { MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only self") {
func f(_ p: Parameter) -> Tracked<Float> { func f(_ p: ParameterTracked) -> Tracked<Float> {
return 100 * p.multiplied(with: 200) return 100 * p.multiplied(with: 200)
} }
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f)) expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f)) expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
} }
MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only non-self") { MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only non-self") {
func f(_ other: Tracked<Float>) -> Tracked<Float> { func f(_ other: Tracked<Float>) -> Tracked<Float> {
return 100 * Parameter(x: 200).multiplied(with: other) return 100 * ParameterTracked(x: 200).multiplied(with: other)
} }
expectEqual(100 * 200, gradient(at: 1, of: f)) expectEqual(100 * 200, gradient(at: 1, of: f))
expectEqual(100 * 200, gradient(at: 2, of: f)) expectEqual(100 * 200, gradient(at: 2, of: f))
@@ -107,51 +255,52 @@ MethodTests.testWithLeakChecking(
"instance method with generated adjoint, wrt self and non-self" "instance method with generated adjoint, wrt self and non-self"
) { ) {
expectEqual( expectEqual(
(Parameter(x: 100), 200), gradient(at: Parameter(x: 200), 100) { $0.multiplied(with: $1) }) (ParameterTracked(x: 100), 200), gradient(at: ParameterTracked(x: 200), 100) { $0.multiplied(with: $1) })
expectEqual( expectEqual(
(Parameter(x: 200), 100), gradient(at: Parameter(x: 100), 200) { $0.multiplied(with: $1) }) (ParameterTracked(x: 200), 100), gradient(at: ParameterTracked(x: 100), 200) { $0.multiplied(with: $1) })
} }
MethodTests.testWithLeakChecking( MethodTests.testWithLeakChecking(
"static method with generated adjoint, called from differentiated func" "static method with generated adjoint, called from differentiated func"
) { ) {
func f(_ p: Parameter) -> Tracked<Float> { func f(_ p: ParameterTracked) -> Tracked<Float> {
return 100 * Parameter.squared(p: p) return 100 * ParameterTracked.squared(p: p)
} }
expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f)) expectEqual(ParameterTracked(x: 4 * 100), gradient(at: ParameterTracked(x: 2), of: f))
expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f)) expectEqual(ParameterTracked(x: 40 * 100), gradient(at: ParameterTracked(x: 20), of: f))
} }
MethodTests.testWithLeakChecking( MethodTests.testWithLeakChecking(
"static method with generated adjoint, differentiated directly" "static method with generated adjoint, differentiated directly"
) { ) {
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: Parameter.squared)) expectEqual(ParameterTracked(x: 4), gradient(at: ParameterTracked(x: 2), of: ParameterTracked.squared))
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: Parameter.squared)) expectEqual(ParameterTracked(x: 40), gradient(at: ParameterTracked(x: 20), of: ParameterTracked.squared))
} }
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only first param") { MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only first param") {
func f(_ p: Parameter) -> Tracked<Float> { func f(_ p: ParameterTracked) -> Tracked<Float> {
return 100 * (p * Parameter(x: 200)) return 100 * (p * ParameterTracked(x: 200))
} }
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f)) expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f)) expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
} }
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only second param") { MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only second param") {
func f(_ p: Parameter) -> Tracked<Float> { func f(_ p: ParameterTracked) -> Tracked<Float> {
return 100 * (Parameter(x: 200) * p) return 100 * (ParameterTracked(x: 200) * p)
} }
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f)) expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f)) expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
} }
MethodTests.testWithLeakChecking("static method with generated adjoint, wrt all params") { MethodTests.testWithLeakChecking("static method with generated adjoint, wrt all params") {
func g(a: Parameter, b: Parameter) -> Tracked<Float> { a * b } func g(a: ParameterTracked, b: ParameterTracked) -> Tracked<Float> { a * b }
expectEqual((Parameter(x: 100), Parameter(x: 200)), expectEqual((ParameterTracked(x: 100), ParameterTracked(x: 200)),
gradient(at: Parameter(x: 200), Parameter(x: 100), of: g)) gradient(at: ParameterTracked(x: 200), ParameterTracked(x: 100), of: g))
expectEqual((Parameter(x: 200), Parameter(x: 100)), expectEqual((ParameterTracked(x: 200), ParameterTracked(x: 100)),
gradient(at: Parameter(x: 100), Parameter(x: 200), of: g)) gradient(at: ParameterTracked(x: 100), ParameterTracked(x: 200), of: g))
} }
*/
// ==== Tests with custom adjoint ==== // ==== Tests with custom adjoint ====
@@ -174,27 +323,27 @@ struct DiffWrtSelf : Differentiable {
} }
struct CustomParameter : Equatable { struct CustomParameter : Equatable {
let storedX: Tracked<Float> let storedX: Float
@differentiable(reverse, wrt: (self)) @differentiable(reverse, wrt: (self))
var x: Tracked<Float> { var x: Float {
return storedX return storedX
} }
init(x: Tracked<Float>) { init(x: Float) {
storedX = x storedX = x
} }
@derivative(of: x) @derivative(of: x)
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) { func vjpX() -> (value: Float, pullback: (Float) -> CustomParameter) {
return (x, { dx in CustomParameter(x: dx) }) return (x, { dx in CustomParameter(x: dx) })
} }
} }
extension CustomParameter : Differentiable, AdditiveArithmetic { extension CustomParameter : Differentiable, AdditiveArithmetic {
typealias TangentVector = CustomParameter typealias TangentVector = CustomParameter
typealias Scalar = Tracked<Float> typealias Scalar = Float
typealias Shape = () typealias Shape = ()
init(repeating repeatedValue: Tracked<Float>, shape: ()) { init(repeating repeatedValue: Float, shape: ()) {
self.init(x: repeatedValue) self.init(x: repeatedValue)
} }
static func + (lhs: CustomParameter, rhs: CustomParameter) -> CustomParameter { static func + (lhs: CustomParameter, rhs: CustomParameter) -> CustomParameter {
@@ -209,38 +358,256 @@ extension CustomParameter : Differentiable, AdditiveArithmetic {
static var zero: CustomParameter { return CustomParameter(x: 0) } static var zero: CustomParameter { return CustomParameter(x: 0) }
} }
extension Tracked where T : FloatingPoint { extension Float {
func clamped(to limits: ClosedRange<Tracked<T>>) -> Tracked<T> { func clamped(to limits: ClosedRange<Float>) -> Float {
return min(max(self, limits.lowerBound), limits.upperBound) return min(max(self, limits.lowerBound), limits.upperBound)
} }
} }
extension CustomParameter { extension CustomParameter {
@differentiable(reverse, wrt: (self)) @differentiable(reverse, wrt: (self))
func squared() -> Tracked<Float> { func squared() -> Float {
return x * x return x * x
} }
@derivative(of: squared) @derivative(of: squared)
func dSquared() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) { func dSquared() -> (value: Float, pullback: (Float) -> CustomParameter) {
return (squared(), { [x] v in CustomParameter(x: (2 * x).clamped(to: -10.0...10.0) * v) }) return (squared(), { [x] v in CustomParameter(x: (2 * x).clamped(to: -10.0...10.0) * v) })
} }
@differentiable(reverse) @differentiable(reverse)
static func squared(p: CustomParameter) -> Tracked<Float> { static func squared(p: CustomParameter) -> Float {
return p.x * p.x return p.x * p.x
} }
@derivative(of: squared) @derivative(of: squared)
static func dSquared( static func dSquared(
_ p: CustomParameter _ p: CustomParameter
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) { ) -> (value: Float, pullback: (Float) -> CustomParameter) {
return (p.x * p.x, { v in CustomParameter(x: (2 * p.x).clamped(to: -10.0...10.0) * v) }) return (p.x * p.x, { v in CustomParameter(x: (2 * p.x).clamped(to: -10.0...10.0) * v) })
} }
// There is currently no way to define multiple custom VJPs wrt different // There is currently no way to define multiple custom VJPs wrt different
// parameters on the same func, so we define a copy of this func per adjoint. // parameters on the same func, so we define a copy of this func per adjoint.
@differentiable(reverse, wrt: (self, other))
func multiplied(with other: Float) -> Float {
return x * other
}
@differentiable(reverse, wrt: (other))
func multiplied_constSelf(with other: Float) -> Float {
return x * other
}
@differentiable(reverse, wrt: (self))
func multiplied_constOther(with other: Float) -> Float {
return x * other
}
@derivative(of: multiplied)
func dMultiplied_wrtAll(
with other: Float
) -> (value: Float, pullback: (Float) -> (CustomParameter, Float)) {
return (multiplied(with: other),
{ [x] v in (CustomParameter(x: other.clamped(to: -10.0...10.0) * v),
x.clamped(to: -10.0...10.0) * v) })
}
@derivative(of: multiplied_constSelf, wrt: other)
func dMultiplied_wrtOther(
with other: Float
) -> (value: Float, pullback: (Float) -> Float) {
let (r, pb) = dMultiplied_wrtAll(with: other)
return (r, { v in pb(v).1 })
}
@derivative(of: multiplied_constOther, wrt: self)
func dMultiplied_wrtSelf(
with other: Float
) -> (value: Float, pullback: (Float) -> CustomParameter) {
let (r, pb) = dMultiplied_wrtAll(with: other)
return (r, { v in pb(v).0 })
}
@differentiable(reverse)
static func multiply(_ lhs: CustomParameter, _ rhs: CustomParameter)
-> Float {
return lhs.x * rhs.x
}
@differentiable(reverse, wrt: (rhs))
static func multiply_constLhs(_ lhs: CustomParameter, _ rhs: CustomParameter) -> Float {
return lhs.x * rhs.x
}
@derivative(of: multiply)
static func dMultiply_wrtAll(_ lhs: CustomParameter,_ rhs: CustomParameter)
-> (value: Float, pullback: (Float) -> (CustomParameter, CustomParameter)) {
let result = multiply(lhs, rhs)
return (result, { v in (CustomParameter(x: rhs.x.clamped(to: -10.0...10.0) * v),
CustomParameter(x: lhs.x.clamped(to: -10.0...10.0) * v)) })
}
@derivative(of: multiply_constLhs, wrt: rhs)
static func dMultiply_wrtRhs(_ lhs: CustomParameter, _ rhs: CustomParameter)
-> (value: Float, pullback: (Float) -> CustomParameter) {
let (r, pb) = dMultiply_wrtAll(lhs, rhs)
return (r, { v in pb(v).1 })
}
}
MethodTests.test(
"instance method with custom adjoint, called from differentiated func"
) {
func f(_ p: CustomParameter) -> Float {
return 100 * p.squared()
}
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
}
MethodTests.test("instance method with generated adjoint, differentiated directly") {
// This is our current syntax for taking gradients of instance methods
// directly. If/when we develop nicer syntax for this, change this test.
func g(p: CustomParameter) -> Float { p.squared() }
expectEqual(CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: g))
expectEqual(CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: g))
}
MethodTests.test("static method with custom adjoint, called from differentiated func") {
func f(_ p: CustomParameter) -> Float {
return 100 * CustomParameter.squared(p: p)
}
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
}
MethodTests.test("static method with custom adjoint, differentiated directly") {
expectEqual(
CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: CustomParameter.squared))
expectEqual(
CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: CustomParameter.squared))
}
MethodTests.test("instance method with custom adjoint, wrt only self") {
func f(_ p: CustomParameter) -> Float {
return 100 * p.multiplied_constOther(with: 200)
}
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
}
MethodTests.test("instance method with custom adjoint, wrt only non-self") {
func f(_ other: Float) -> Float {
return 100 * CustomParameter(x: 200).multiplied_constSelf(with: other)
}
expectEqual(100 * 10, gradient(at: 1, of: f))
expectEqual(100 * 10, gradient(at: 2, of: f))
}
MethodTests.test("instance method with custom adjoint, wrt self and non-self") {
func g(p: CustomParameter, o: Float) -> Float { p.multiplied(with: o) }
expectEqual((CustomParameter(x: 5), 10), gradient(at: CustomParameter(x: 100), 5, of: g))
expectEqual((CustomParameter(x: 10), 5), gradient(at: CustomParameter(x: 5), 100, of: g))
}
MethodTests.test("static method with custom adjoint, wrt only lhs") {
func f(_ p: CustomParameter) -> Float {
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
}
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
}
MethodTests.test("static method with custom adjoint, wrt only rhs") {
func f(_ p: CustomParameter) -> Float {
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
}
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
}
MethodTests.test("static method with custom adjoint, wrt all") {
func f(_ a: CustomParameter, _ b: CustomParameter) -> Float {
return CustomParameter.multiply(a, b)
}
expectEqual((CustomParameter(x: 5), CustomParameter(x: 10)),
gradient(at: CustomParameter(x: 100), CustomParameter(x: 5), of: f))
expectEqual((CustomParameter(x: 10), CustomParameter(x: 5)),
gradient(at: CustomParameter(x: 5), CustomParameter(x: 100), of: f))
}
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
We cannot use `Tracked<T>` :(
struct CustomParameterTracked : Equatable {
let storedX: Tracked<Float>
@differentiable(reverse, wrt: (self))
var x: Tracked<Float> {
return storedX
}
init(x: Tracked<Float>) {
storedX = x
}
@derivative(of: x)
func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
return (x, { dx in CustomParameterTracked(x: dx) })
}
}
extension CustomParameterTracked : Differentiable, AdditiveArithmetic {
typealias TangentVector = CustomParameterTracked
typealias Scalar = Tracked<Float>
typealias Shape = ()
init(repeating repeatedValue: Tracked<Float>, shape: ()) {
self.init(x: repeatedValue)
}
static func + (lhs: CustomParameterTracked, rhs: CustomParameterTracked) -> CustomParameterTracked {
return CustomParameterTracked(x: lhs.x + rhs.x)
}
static func - (lhs: CustomParameterTracked, rhs: CustomParameterTracked) -> CustomParameterTracked {
return CustomParameterTracked(x: lhs.x - rhs.x)
}
static func * (lhs: Scalar, rhs: CustomParameterTracked) -> CustomParameterTracked {
return CustomParameterTracked(x: lhs * rhs.x)
}
static var zero: CustomParameterTracked { return CustomParameterTracked(x: 0) }
}
extension Tracked where T : FloatingPoint {
func clamped(to limits: ClosedRange<Tracked<T>>) -> Tracked<T> {
return min(max(self, limits.lowerBound), limits.upperBound)
}
}
extension CustomParameterTracked {
@differentiable(reverse, wrt: (self))
func squared() -> Tracked<Float> {
return x * x
}
@derivative(of: squared)
func dSquared() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
return (squared(), { [x] v in CustomParameterTracked(x: (2 * x).clamped(to: -10.0...10.0) * v) })
}
@differentiable(reverse)
static func squared(p: CustomParameterTracked) -> Tracked<Float> {
return p.x * p.x
}
@derivative(of: squared)
static func dSquared(
_ p: CustomParameterTracked
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
return (p.x * p.x, { v in CustomParameterTracked(x: (2 * p.x).clamped(to: -10.0...10.0) * v) })
}
// There is currently no way to define multiple custom VJPs wrt different
// parameters on the same func, so we define a copy of this func per adjoint.
@differentiable(reverse, wrt: (self, other)) @differentiable(reverse, wrt: (self, other))
func multiplied(with other: Tracked<Float>) -> Tracked<Float> { func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
return x * other return x * other
@@ -259,9 +626,9 @@ extension CustomParameter {
@derivative(of: multiplied) @derivative(of: multiplied)
func dMultiplied_wrtAll( func dMultiplied_wrtAll(
with other: Tracked<Float> with other: Tracked<Float>
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameter, Tracked<Float>)) { ) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameterTracked, Tracked<Float>)) {
return (multiplied(with: other), return (multiplied(with: other),
{ [x] v in (CustomParameter(x: other.clamped(to: -10.0...10.0) * v), { [x] v in (CustomParameterTracked(x: other.clamped(to: -10.0...10.0) * v),
x.clamped(to: -10.0...10.0) * v) }) x.clamped(to: -10.0...10.0) * v) })
} }
@@ -276,33 +643,33 @@ extension CustomParameter {
@derivative(of: multiplied_constOther, wrt: self) @derivative(of: multiplied_constOther, wrt: self)
func dMultiplied_wrtSelf( func dMultiplied_wrtSelf(
with other: Tracked<Float> with other: Tracked<Float>
) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) { ) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
let (r, pb) = dMultiplied_wrtAll(with: other) let (r, pb) = dMultiplied_wrtAll(with: other)
return (r, { v in pb(v).0 }) return (r, { v in pb(v).0 })
} }
@differentiable(reverse) @differentiable(reverse)
static func multiply(_ lhs: CustomParameter, _ rhs: CustomParameter) static func multiply(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked)
-> Tracked<Float> { -> Tracked<Float> {
return lhs.x * rhs.x return lhs.x * rhs.x
} }
@differentiable(reverse, wrt: (rhs)) @differentiable(reverse, wrt: (rhs))
static func multiply_constLhs(_ lhs: CustomParameter, _ rhs: CustomParameter) -> Tracked<Float> { static func multiply_constLhs(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked) -> Tracked<Float> {
return lhs.x * rhs.x return lhs.x * rhs.x
} }
@derivative(of: multiply) @derivative(of: multiply)
static func dMultiply_wrtAll(_ lhs: CustomParameter,_ rhs: CustomParameter) static func dMultiply_wrtAll(_ lhs: CustomParameterTracked,_ rhs: CustomParameterTracked)
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameter, CustomParameter)) { -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameterTracked, CustomParameterTracked)) {
let result = multiply(lhs, rhs) let result = multiply(lhs, rhs)
return (result, { v in (CustomParameter(x: rhs.x.clamped(to: -10.0...10.0) * v), return (result, { v in (CustomParameterTracked(x: rhs.x.clamped(to: -10.0...10.0) * v),
CustomParameter(x: lhs.x.clamped(to: -10.0...10.0) * v)) }) CustomParameterTracked(x: lhs.x.clamped(to: -10.0...10.0) * v)) })
} }
@derivative(of: multiply_constLhs, wrt: rhs) @derivative(of: multiply_constLhs, wrt: rhs)
static func dMultiply_wrtRhs(_ lhs: CustomParameter, _ rhs: CustomParameter) static func dMultiply_wrtRhs(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked)
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameter) { -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
let (r, pb) = dMultiply_wrtAll(lhs, rhs) let (r, pb) = dMultiply_wrtAll(lhs, rhs)
return (r, { v in pb(v).1 }) return (r, { v in pb(v).1 })
} }
@@ -311,82 +678,83 @@ extension CustomParameter {
MethodTests.testWithLeakChecking( MethodTests.testWithLeakChecking(
"instance method with custom adjoint, called from differentiated func" "instance method with custom adjoint, called from differentiated func"
) { ) {
func f(_ p: CustomParameter) -> Tracked<Float> { func f(_ p: CustomParameterTracked) -> Tracked<Float> {
return 100 * p.squared() return 100 * p.squared()
} }
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f)) expectEqual(CustomParameterTracked(x: 4 * 100), gradient(at: CustomParameterTracked(x: 2), of: f))
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f)) expectEqual(CustomParameterTracked(x: 10 * 100), gradient(at: CustomParameterTracked(x: 20), of: f))
} }
MethodTests.testWithLeakChecking("instance method with generated adjoint, differentiated directly") { MethodTests.testWithLeakChecking("instance method with generated adjoint, differentiated directly") {
// This is our current syntax for taking gradients of instance methods // This is our current syntax for taking gradients of instance methods
// directly. If/when we develop nicer syntax for this, change this test. // directly. If/when we develop nicer syntax for this, change this test.
func g(p: CustomParameter) -> Tracked<Float> { p.squared() } func g(p: CustomParameterTracked) -> Tracked<Float> { p.squared() }
expectEqual(CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: g)) expectEqual(CustomParameterTracked(x: 4), gradient(at: CustomParameterTracked(x: 2), of: g))
expectEqual(CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: g)) expectEqual(CustomParameterTracked(x: 10), gradient(at: CustomParameterTracked(x: 20), of: g))
} }
MethodTests.testWithLeakChecking("static method with custom adjoint, called from differentiated func") { MethodTests.testWithLeakChecking("static method with custom adjoint, called from differentiated func") {
func f(_ p: CustomParameter) -> Tracked<Float> { func f(_ p: CustomParameterTracked) -> Tracked<Float> {
return 100 * CustomParameter.squared(p: p) return 100 * CustomParameterTracked.squared(p: p)
} }
expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f)) expectEqual(CustomParameterTracked(x: 4 * 100), gradient(at: CustomParameterTracked(x: 2), of: f))
expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f)) expectEqual(CustomParameterTracked(x: 10 * 100), gradient(at: CustomParameterTracked(x: 20), of: f))
} }
MethodTests.testWithLeakChecking("static method with custom adjoint, differentiated directly") { MethodTests.testWithLeakChecking("static method with custom adjoint, differentiated directly") {
expectEqual( expectEqual(
CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: CustomParameter.squared)) CustomParameterTracked(x: 4), gradient(at: CustomParameterTracked(x: 2), of: CustomParameterTracked.squared))
expectEqual( expectEqual(
CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: CustomParameter.squared)) CustomParameterTracked(x: 10), gradient(at: CustomParameterTracked(x: 20), of: CustomParameterTracked.squared))
} }
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only self") { MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only self") {
func f(_ p: CustomParameter) -> Tracked<Float> { func f(_ p: CustomParameterTracked) -> Tracked<Float> {
return 100 * p.multiplied_constOther(with: 200) return 100 * p.multiplied_constOther(with: 200)
} }
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f)) expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f)) expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
} }
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only non-self") { MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only non-self") {
func f(_ other: Tracked<Float>) -> Tracked<Float> { func f(_ other: Tracked<Float>) -> Tracked<Float> {
return 100 * CustomParameter(x: 200).multiplied_constSelf(with: other) return 100 * CustomParameterTracked(x: 200).multiplied_constSelf(with: other)
} }
expectEqual(100 * 10, gradient(at: 1, of: f)) expectEqual(100 * 10, gradient(at: 1, of: f))
expectEqual(100 * 10, gradient(at: 2, of: f)) expectEqual(100 * 10, gradient(at: 2, of: f))
} }
MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt self and non-self") { MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt self and non-self") {
func g(p: CustomParameter, o: Tracked<Float>) -> Tracked<Float> { p.multiplied(with: o) } func g(p: CustomParameterTracked, o: Tracked<Float>) -> Tracked<Float> { p.multiplied(with: o) }
expectEqual((CustomParameter(x: 5), 10), gradient(at: CustomParameter(x: 100), 5, of: g)) expectEqual((CustomParameterTracked(x: 5), 10), gradient(at: CustomParameterTracked(x: 100), 5, of: g))
expectEqual((CustomParameter(x: 10), 5), gradient(at: CustomParameter(x: 5), 100, of: g)) expectEqual((CustomParameterTracked(x: 10), 5), gradient(at: CustomParameterTracked(x: 5), 100, of: g))
} }
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only lhs") { MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only lhs") {
func f(_ p: CustomParameter) -> Tracked<Float> { func f(_ p: CustomParameterTracked) -> Tracked<Float> {
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p) return 100 * CustomParameterTracked.multiply_constLhs(CustomParameterTracked(x: 200), p)
} }
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f)) expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f)) expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
} }
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only rhs") { MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only rhs") {
func f(_ p: CustomParameter) -> Tracked<Float> { func f(_ p: CustomParameterTracked) -> Tracked<Float> {
return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p) return 100 * CustomParameterTracked.multiply_constLhs(CustomParameterTracked(x: 200), p)
} }
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f)) expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f)) expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
} }
MethodTests.testWithLeakChecking("static method with custom adjoint, wrt all") { MethodTests.testWithLeakChecking("static method with custom adjoint, wrt all") {
func f(_ a: CustomParameter, _ b: CustomParameter) -> Tracked<Float> { func f(_ a: CustomParameterTracked, _ b: CustomParameterTracked) -> Tracked<Float> {
return CustomParameter.multiply(a, b) return CustomParameterTracked.multiply(a, b)
} }
expectEqual((CustomParameter(x: 5), CustomParameter(x: 10)), expectEqual((CustomParameterTracked(x: 5), CustomParameterTracked(x: 10)),
gradient(at: CustomParameter(x: 100), CustomParameter(x: 5), of: f)) gradient(at: CustomParameterTracked(x: 100), CustomParameterTracked(x: 5), of: f))
expectEqual((CustomParameter(x: 10), CustomParameter(x: 5)), expectEqual((CustomParameterTracked(x: 10), CustomParameterTracked(x: 5)),
gradient(at: CustomParameter(x: 5), CustomParameter(x: 100), of: f)) gradient(at: CustomParameterTracked(x: 5), CustomParameterTracked(x: 100), of: f))
} }
*/
runAllTests() runAllTests()

View File

@@ -1,4 +1,4 @@
// RUN: %target-run-simple-swift 1// RUN: %target-run-simple-swift
// REQUIRES: executable_test // REQUIRES: executable_test
import StdlibUnittest import StdlibUnittest
@@ -6,7 +6,19 @@ import DifferentiationUnittest
var RepeatedCallsTests = TestSuite("RepeatedCalls") var RepeatedCallsTests = TestSuite("RepeatedCalls")
RepeatedCallsTests.testWithLeakChecking("Repeat") { RepeatedCallsTests.test("Repeat") {
func mul2(_ x: Float) -> Float {
return 2 * x
}
func mul4(_ x: Float) -> Float {
return mul2(mul2(x))
}
expectEqual(4, gradient(at: 0, of: mul4))
}
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
We cannot use `Tracked<T>` :(
RepeatedCallsTests.testWithLeakChecking("Repeat-Tracked") {
func mul2(_ x: Tracked<Float>) -> Tracked<Float> { func mul2(_ x: Tracked<Float>) -> Tracked<Float> {
return 2 * x return 2 * x
} }
@@ -15,5 +27,6 @@ RepeatedCallsTests.testWithLeakChecking("Repeat") {
} }
expectEqual(4, gradient(at: 0, of: mul4)) expectEqual(4, gradient(at: 0, of: mul4))
} }
*/
runAllTests() runAllTests()

View File

@@ -266,6 +266,9 @@ SimpleMathTests.test("TupleMutation") {
} }
// Tests TF-321. // Tests TF-321.
/* Temporary disabled until https://github.com/swiftlang/swift/issues/84840 is fixed
We cannot use `Tracked<T>` :(
SimpleMathTests.test("TupleNonDifferentiableElements") { SimpleMathTests.test("TupleNonDifferentiableElements") {
// TF-964: Test tuple with non-tuple-typed adjoint value. // TF-964: Test tuple with non-tuple-typed adjoint value.
func tupleLet(_ x: Tracked<Float>) -> Tracked<Float> { func tupleLet(_ x: Tracked<Float>) -> Tracked<Float> {
@@ -309,6 +312,51 @@ SimpleMathTests.test("TupleNonDifferentiableElements") {
} }
expectEqual((3, 1), valueWithGradient(at: 3, of: wrapper)) expectEqual((3, 1), valueWithGradient(at: 3, of: wrapper))
} }
*/
SimpleMathTests.test("TupleNonDifferentiableElementsNotTracked") {
// TF-964: Test tuple with non-tuple-typed adjoint value.
func tupleLet(_ x: Float) -> Float {
let tuple = (2 * x, 1)
return tuple.0
}
expectEqual((8, 2), valueWithGradient(at: 4, of: tupleLet))
func tupleVar(_ x: Float) -> Float {
var tuple = (x, 1)
tuple.0 = x
tuple.1 = 1
return tuple.0
}
expectEqual((3, 1), valueWithGradient(at: 3, of: tupleVar))
func nested(_ x: Float) -> Float {
// Convoluted function computing `x * x`.
var tuple: (Int, (Int, Float), Float) = (1, (1, 0), 0)
tuple.0 = 1
tuple.1.0 = 1
tuple.1.1 = x
tuple.2 = x
return tuple.1.1 * tuple.2
}
expectEqual((16, 8), valueWithGradient(at: 4, of: nested))
struct Wrapper<T> {
@differentiable(reverse where T : Differentiable)
func baz(_ x: T) -> T {
var tuple = (1, 1, x, 1)
tuple.0 = 1
tuple.2 = x
tuple.3 = 1
return tuple.2
}
}
func wrapper(_ x: Float) -> Float {
let w = Wrapper<Float>()
return w.baz(x)
}
expectEqual((3, 1), valueWithGradient(at: 3, of: wrapper))
}
// Tests TF-21. // Tests TF-21.
SimpleMathTests.test("StructMemberwiseInitializer") { SimpleMathTests.test("StructMemberwiseInitializer") {

View File

@@ -21,9 +21,10 @@ func test(arr: [any P]) {
} }
} }
// Make sure that we don't pick a concrete `(AnyHashable, AnyHashable) -> Bool` overload.
// CHECK: sil private [ossa] @$s34anyhashable_and_operator_filtering4test3arrySayAA1P_pG_tFyAaD_pXEfU_ // CHECK: sil private [ossa] @$s34anyhashable_and_operator_filtering4test3arrySayAA1P_pG_tFyAaD_pXEfU_
// CHECK: [[LHS_ARG:%.*]] = alloc_stack $E // CHECK: [[LHS_ARG:%.*]] = alloc_stack $E
// CHECK: [[RHS_ARG:%.*]] = alloc_stack $E // CHECK: [[RHS_ARG:%.*]] = alloc_stack $E
// CHECK: function_ref == infix<A>(_:_:) // CHECK: [[GENERIC_OP:%.*]] = witness_method $E, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
// CHECK-NEXT: [[GENERIC_OP:%.*]] = function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF : $@convention(thin) <τ_0_0 where τ_0_0 : RawRepresentable, τ_0_0.RawValue : Equatable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> Bool // CHECK-NEXT: apply [[GENERIC_OP]]<E>([[LHS_ARG]], [[RHS_ARG]], {{.*}})
// CHECK-NEXT: apply [[GENERIC_OP]]<E>([[LHS_ARG]], [[RHS_ARG]])

View File

@@ -8,9 +8,8 @@ struct Value: Equatable, ExpressibleByNilLiteral {
} }
// CHECK-LABEL: sil hidden [ossa] @$s13rdar1580631514test1vyAA5ValueV_tF : $@convention(thin) (Value) -> () // CHECK-LABEL: sil hidden [ossa] @$s13rdar1580631514test1vyAA5ValueV_tF : $@convention(thin) (Value) -> ()
// function_ref static Value.__derived_struct_equals(_:_:) // CHECK: [[EQUALS_REF:%.*]] = witness_method $Value, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
// CHECK: [[EQUALS_REF:%.*]] = function_ref @$s13rdar1580631515ValueV23__derived_struct_equalsySbAC_ACtFZ // CHECK-NEXT: apply [[EQUALS_REF]]<Value>({{.*}})
// CHECK-NEXT: apply [[EQUALS_REF]](%0, {{.*}})
func test(v: Value) { func test(v: Value) {
_ = v == nil _ = v == nil
} }

View File

@@ -12,6 +12,6 @@ func testFoo() {
// CHECK: [[@LINE+1]]:7 | instance-method/Swift | hash(into:) | s:14swift_ide_test9CustomFooV4hash4intoys6HasherVz_tF | {{.*}}Ref // CHECK: [[@LINE+1]]:7 | instance-method/Swift | hash(into:) | s:14swift_ide_test9CustomFooV4hash4intoys6HasherVz_tF | {{.*}}Ref
f.hash(into: &hasher) f.hash(into: &hasher)
hasher.finalize() hasher.finalize()
// CHECK: [[@LINE+1]]:11 | static-method/Swift | __derived_struct_equals(_:_:) | s:14swift_ide_test9CustomFooV23__derived_struct_equalsySbAC_ACtFZ | {{.*}}Ref // CHECK: [[@LINE+1]]:11 | static-method/infix-operator/Swift | ==(_:_:) | s:SQ2eeoiySbx_xtFZ | {{.*}}Ref
_ = f == CustomFoo(a: 0, b: "b") _ = f == CustomFoo(a: 0, b: "b")
} }

View File

@@ -48,15 +48,15 @@ public enum Alphabet : String {
// CHECK-LABEL: sil [ossa] @$s4main14check_alphabetySiAA8AlphabetOF : $@convention(thin) (Alphabet) -> Int { // CHECK-LABEL: sil [ossa] @$s4main14check_alphabetySiAA8AlphabetOF : $@convention(thin) (Alphabet) -> Int {
public func check_alphabet(_ state : Alphabet) -> Int { public func check_alphabet(_ state : Alphabet) -> Int {
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF // FRAGILE: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF // RESILIENT: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
return state == .E ? 1 : 0 return state == .E ? 1 : 0
} }
// CHECK-LABEL: sil [ossa] @$s4main9compareItySbAA8AlphabetO_ADtF : $@convention(thin) (Alphabet, Alphabet) -> Bool { // CHECK-LABEL: sil [ossa] @$s4main9compareItySbAA8AlphabetO_ADtF : $@convention(thin) (Alphabet, Alphabet) -> Bool {
public func compareIt(_ state : Alphabet, _ rhs: Alphabet) -> Bool { public func compareIt(_ state : Alphabet, _ rhs: Alphabet) -> Bool {
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF // FRAGILE: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF // RESILIENT: witness_method $Alphabet, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
return state == rhs return state == rhs
} }
@@ -67,14 +67,14 @@ public enum AlphabetInt : Int {
// CHECK-LABEL: sil [ossa] @$s4main18check_alphabet_intySiAA11AlphabetIntOF : $@convention(thin) (AlphabetInt) -> Int { // CHECK-LABEL: sil [ossa] @$s4main18check_alphabet_intySiAA11AlphabetIntOF : $@convention(thin) (AlphabetInt) -> Int {
public func check_alphabet_int(_ state : AlphabetInt) -> Int { public func check_alphabet_int(_ state : AlphabetInt) -> Int {
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF // FRAGILE: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF // RESILIENT: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
return state == .E ? 1 : 0 return state == .E ? 1 : 0
} }
// CHECK-LABEL: sil [ossa] @$s4main9compareItySbAA11AlphabetIntO_ADtF : $@convention(thin) (AlphabetInt, AlphabetInt) -> Bool { // CHECK-LABEL: sil [ossa] @$s4main9compareItySbAA11AlphabetIntO_ADtF : $@convention(thin) (AlphabetInt, AlphabetInt) -> Bool {
public func compareIt(_ state : AlphabetInt, _ rhs: AlphabetInt) -> Bool { public func compareIt(_ state : AlphabetInt, _ rhs: AlphabetInt) -> Bool {
// FRAGILE: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF // FRAGILE: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
// RESILIENT: function_ref @$ss2eeoiySbx_xtSYRzSQ8RawValueRpzlF // RESILIENT: witness_method $AlphabetInt, #Equatable."==" : <Self where Self : Equatable> (Self.Type) -> (Self, Self) -> Bool
return state == rhs return state == rhs
} }

View File

@@ -1,4 +1,7 @@
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s // RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=SILGEN
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s --check-prefix=OPTIMIZED
// Operators are no longer devirtualized at AST level, it's done during SIL optimization.
infix operator +++ infix operator +++
@@ -11,9 +14,13 @@ struct Branch : Twig {
static func doIt(_: Branch, _: Branch) {} static func doIt(_: Branch, _: Branch) {}
} }
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () { // SILGEN-LABEL: sil hidden [ossa] @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
// CHECK: function_ref @$s18protocol_operators6BranchV4doItyyAC_ACtFZ : $@convention(method) (Branch, Branch, @thin Branch.Type) -> () // SILGEN: witness_method $Branch, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> ()
// CHECK: return // SILGEN: return
// OPTIMIZED-LABEL: sil hidden @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
// OPTIMIZED: function_ref @$s18protocol_operators6BranchV4doItyyAC_ACtFZ
// OPTIMIZED: return
func useBranch(_ b: Branch) { func useBranch(_ b: Branch) {
b +++ b b +++ b
} }
@@ -28,11 +35,17 @@ class Stuck : Stick, ExpressibleByIntegerLiteral {
required init(integerLiteral: Int) {} required init(integerLiteral: Int) {}
} }
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () { // SILGEN-LABEL: sil hidden [ossa] @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> () // SILGEN: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> () // SILGEN: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
// CHECK: witness_method $Stuck, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // SILGEN: witness_method $Stuck, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: return // SILGEN: return
// OPTIMIZED-LABEL: sil hidden @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
// OPTIMIZED: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
// OPTIMIZED: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
// OPTIMIZED: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
// OPTIMIZED: return
func useStick(_ a: Stuck, _ b: Stick) { func useStick(_ a: Stuck, _ b: Stick) {
_ = a +++ b _ = a +++ b
_ = b +++ b _ = b +++ b
@@ -49,10 +62,15 @@ class Rope : Twine<Int>, ExpressibleByIntegerLiteral {
required init(integerLiteral: Int) {} required init(integerLiteral: Int) {}
} }
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () { // SILGEN-LABEL: sil hidden [ossa] @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> () // SILGEN: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> () // SILGEN: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
// CHECK: witness_method $Rope, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // SILGEN: witness_method $Rope, #Twig."+++" : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// OPTIMIZED-LABEL: sil hidden @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
// OPTIMIZED: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
// OPTIMIZED: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
// OPTIMIZED: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
func useRope(_ r: Rope, _ s: Rope) { func useRope(_ r: Rope, _ s: Rope) {
_ = r +++ s _ = r +++ s
_ = s +++ s _ = s +++ s