//===--- Common.cpp - Automatic differentiation common utils --*- C++ -*---===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // Automatic differentiation common utilities. // //===----------------------------------------------------------------------===// #include "swift/Basic/STLExtras.h" #define DEBUG_TYPE "differentiation" #include "swift/SIL/ApplySite.h" #include "swift/SILOptimizer/Differentiation/Common.h" #include "swift/AST/TypeCheckRequests.h" #include "swift/Basic/Assertions.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" namespace swift { namespace autodiff { raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; } //===----------------------------------------------------------------------===// // Helpers //===----------------------------------------------------------------------===// ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) { // Find the `pointer_to_address` result, peering through `index_addr`. auto *ptai = dyn_cast(v); if (auto *iai = dyn_cast(v)) ptai = dyn_cast(iai->getOperand(0)); if (!ptai) return nullptr; auto *mdi = dyn_cast( ptai->getOperand()->getDefiningInstruction()); if (!mdi) return nullptr; // Return the `array.uninitialized_intrinsic` application, if it exists. 
if (auto *dti = dyn_cast( mdi->getValue()->getDefiningInstruction())) return ArraySemanticsCall(dti->getOperand(), semantics::ARRAY_UNINITIALIZED_INTRINSIC); return nullptr; } bool isSemanticMemberAccessor(SILFunction *original) { auto *dc = original->getDeclContext(); if (!dc) return false; auto *decl = dc->getAsDecl(); if (!decl) return false; auto *accessor = dyn_cast(decl); if (!accessor) return false; // Currently, only getters, setters and _modify accessors are supported. if (accessor->getAccessorKind() != AccessorKind::Get && accessor->getAccessorKind() != AccessorKind::Set && accessor->getAccessorKind() != AccessorKind::Modify) return false; // Accessor must come from a `var` declaration. auto *varDecl = dyn_cast(accessor->getStorage()); if (!varDecl) return false; // Return true for stored property accessors. if (varDecl->hasStorage() && varDecl->isInstanceMember()) return true; // Return true for properties that have attached property wrappers. if (varDecl->hasAttachedPropertyWrapper()) return true; // Otherwise, return false. // User-defined accessors can never be supported because they may use custom // logic that does not semantically perform a member access. 
return false; } bool hasSemanticMemberAccessorCallee(ApplySite applySite) { if (auto *FRI = dyn_cast(applySite.getCallee())) if (auto *F = FRI->getReferencedFunctionOrNull()) return isSemanticMemberAccessor(F); return false; } void forEachApplyDirectResult( FullApplySite applySite, llvm::function_ref resultCallback) { switch (applySite.getKind()) { case FullApplySiteKind::ApplyInst: { auto *ai = cast(applySite.getInstruction()); if (!ai->getType().is()) { resultCallback(ai); return; } if (auto *dti = ai->getSingleUserOfType()) for (auto directResult : dti->getResults()) resultCallback(directResult); break; } case FullApplySiteKind::BeginApplyInst: { auto *bai = cast(applySite.getInstruction()); for (auto directResult : bai->getResults()) resultCallback(directResult); break; } case FullApplySiteKind::TryApplyInst: { auto *tai = cast(applySite.getInstruction()); for (auto *succBB : tai->getSuccessorBlocks()) for (auto *arg : succBB->getArguments()) resultCallback(arg); break; } } } void collectAllFormalResultsInTypeOrder(SILFunction &function, SmallVectorImpl &results) { SILFunctionConventions convs(function.getLoweredFunctionType(), function.getModule()); auto indResults = function.getIndirectResults(); auto *retInst = cast(function.findReturnBB()->getTerminator()); auto retVal = retInst->getOperand(); SmallVector dirResults; if (auto *tupleInst = dyn_cast_or_null(retVal->getDefiningInstruction())) dirResults.append(tupleInst->getElements().begin(), tupleInst->getElements().end()); else dirResults.push_back(retVal); unsigned indResIdx = 0, dirResIdx = 0; for (auto &resInfo : convs.getResults()) results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++] : indResults[indResIdx++]); // Treat semantic result parameters as semantic results. // Append them` parameters after formal results. 
for (auto i : range(convs.getNumParameters())) { auto paramInfo = convs.getParameters()[i]; if (!paramInfo.isAutoDiffSemanticResult()) continue; auto *argument = function.getArgumentsWithoutIndirectResults()[i]; results.push_back(argument); } // Treat yields as semantic results. Note that we can only differentiate // @yield_once with simple control flow, so we can assume that the function // contains only a single `yield` instruction auto yieldIt = std::find_if(function.begin(), function.end(), [](const SILBasicBlock &BB) -> bool { const TermInst *TI = BB.getTerminator(); return isa(TI); }); if (yieldIt != function.end()) { auto *yieldInst = cast(yieldIt->getTerminator()); for (auto yield : yieldInst->getOperandValues()) results.push_back(yield); } } void collectAllDirectResultsInTypeOrder(SILFunction &function, SmallVectorImpl &results) { SILFunctionConventions convs(function.getLoweredFunctionType(), function.getModule()); auto *retInst = cast(function.findReturnBB()->getTerminator()); auto retVal = retInst->getOperand(); if (auto *tupleInst = dyn_cast(retVal)) results.append(tupleInst->getElements().begin(), tupleInst->getElements().end()); else results.push_back(retVal); } void collectAllActualResultsInTypeOrder( FullApplySite fai, ArrayRef extractedDirectResults, SmallVectorImpl &results) { auto calleeConvs = fai.getSubstCalleeConv(); unsigned indResIdx = 0, dirResIdx = 0; for (auto &resInfo : calleeConvs.getResults()) { results.push_back(resInfo.isFormalDirect() ? 
extractedDirectResults[dirResIdx++] : fai.getIndirectSILResults()[indResIdx++]); } } void collectMinimalIndicesForFunctionCall( FullApplySite ai, const AutoDiffConfig &parentConfig, const DifferentiableActivityInfo &activityInfo, SmallVectorImpl &results, SmallVectorImpl ¶mIndices, SmallVectorImpl &resultIndices) { auto calleeFnTy = ai.getSubstCalleeType(); auto calleeConvs = ai.getSubstCalleeConv(); // Parameter indices are indices (in the callee type signature) of parameter // arguments that are varied or are arguments. // Record all parameter indices in type order. unsigned currentParamIdx = 0; for (auto applyArg : ai.getArgumentsWithoutIndirectResults()) { if (activityInfo.isActive(applyArg, parentConfig)) paramIndices.push_back(currentParamIdx); ++currentParamIdx; } // Result indices are indices (in the callee type signature) of results that // are useful. SmallVector directResults; forEachApplyDirectResult(ai, [&](SILValue directResult) { directResults.push_back(directResult); }); auto indirectResults = ai.getIndirectSILResults(); // Record all results and result indices in type order. results.reserve(calleeFnTy->getNumResults()); unsigned dirResIdx = 0; unsigned indResIdx = calleeConvs.getSILArgIndexOfFirstIndirectResult(); for (const auto &resAndIdx : enumerate(calleeConvs.getResults())) { const auto &res = resAndIdx.value(); unsigned idx = resAndIdx.index(); if (res.isFormalDirect()) { results.push_back(directResults[dirResIdx]); if (auto dirRes = directResults[dirResIdx]) if (dirRes && activityInfo.isActive(dirRes, parentConfig)) resultIndices.push_back(idx); ++dirResIdx; } else { results.push_back(indirectResults[indResIdx]); if (activityInfo.isActive(indirectResults[indResIdx], parentConfig)) resultIndices.push_back(idx); ++indResIdx; } } // Record all semantic result parameters as results. 
auto semanticResultParamResultIndex = calleeFnTy->getNumResults(); for (const auto ¶mAndIdx : enumerate(calleeConvs.getParameters())) { const auto ¶m = paramAndIdx.value(); if (!param.isAutoDiffSemanticResult()) continue; unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults(); results.push_back(ai.getArgument(idx)); resultIndices.push_back(semanticResultParamResultIndex++); } // Record all yields. While we do not have a way to represent direct yields // (_read accessors) we run activity analysis for them. These will be // diagnosed later. if (BeginApplyInst *bai = dyn_cast(*ai)) { for (const auto &yieldAndIdx : enumerate(calleeConvs.getYields())) { results.push_back(bai->getYieldedValues()[yieldAndIdx.index()]); resultIndices.push_back(semanticResultParamResultIndex++); } } // Make sure the function call has active results. #ifndef NDEBUG assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults()); assert(llvm::any_of(results, [&](SILValue result) { return activityInfo.isActive(result, parentConfig); })); #endif } std::optional> findDebugLocationAndVariable(SILValue originalValue) { if (auto *asi = dyn_cast(originalValue)) return swift::transform(asi->getVarInfo(false), [&](SILDebugVariable var) { return std::make_pair(asi->getDebugLocation(), var); }); for (auto *use : originalValue->getUses()) { if (auto *dvi = dyn_cast(use->getUser())) return swift::transform(dvi->getVarInfo(false), [&](SILDebugVariable var) { // We need to drop `op_deref` here as we're transferring debug info // location from debug_value instruction (which describes how to get value) // into alloc_stack (which describes the location) if (var.DIExpr.startsWithDeref()) var.DIExpr.eraseElement(var.DIExpr.element_begin()); return std::make_pair(dvi->getDebugLocation(), var); }); } return std::nullopt; } //===----------------------------------------------------------------------===// // Diagnostic utilities 
//===----------------------------------------------------------------------===//

/// Returns the location of `v`, or the location of the function containing
/// `v` when the value's own location is null or has an invalid source
/// location.
SILLocation getValidLocation(SILValue v) {
  auto loc = v.getLoc();
  if (loc.isNull() || loc.getSourceLoc().isInvalid())
    loc = v->getFunction()->getLocation();
  return loc;
}

/// Returns the location of `inst`, or the location of the function containing
/// `inst` when the instruction's own location is null or has an invalid source
/// location.
SILLocation getValidLocation(SILInstruction *inst) {
  auto loc = inst->getLoc();
  if (loc.isNull() || loc.getSourceLoc().isInvalid())
    loc = inst->getFunction()->getLocation();
  return loc;
}

//===----------------------------------------------------------------------===//
// Tangent property lookup utilities
//===----------------------------------------------------------------------===//

/// Evaluates a `TangentStoredPropertyRequest` for `originalField` in
/// `baseType` and returns the resulting tangent property.
///
/// If the request reports an error, a non-differentiability diagnostic chosen
/// by the error kind is emitted at `loc` on behalf of `invoker`, and `nullptr`
/// is returned.
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
                                  CanType baseType, SILLocation loc,
                                  DifferentiationInvoker invoker) {
  auto &astCtx = context.getASTContext();
  auto tanFieldInfo = evaluateOrDefault(
      astCtx.evaluator, TangentStoredPropertyRequest{originalField, baseType},
      TangentPropertyInfo(nullptr));
  // If no error, return the tangent property.
  if (tanFieldInfo)
    return tanFieldInfo.tangentProperty;
  // Otherwise, diagnose error and return nullptr.
  assert(tanFieldInfo.error);
  auto *parentDC = originalField->getDeclContext();
  assert(parentDC->isTypeContext());
  // Names of the enclosing nominal type and of the field, used as diagnostic
  // arguments below.
  auto parentDeclName = parentDC->getSelfNominalTypeDecl()->getNameStr();
  auto fieldName = originalField->getNameStr();
  auto sourceLoc = loc.getSourceLoc();
  switch (tanFieldInfo.error->kind) {
  case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
    llvm_unreachable(
        "`@noDerivative` stored property accesses should not be "
        "differentiated; activity analysis should not mark as varied");
  case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker,
        diag::autodiff_stored_property_parent_not_differentiable,
        parentDeclName, fieldName);
    break;
  case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_stored_property_not_differentiable,
        parentDeclName, fieldName, originalField->getInterfaceType());
    break;
  case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_stored_property_tangent_not_struct,
        parentDeclName, fieldName);
    break;
  case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker,
        diag::autodiff_stored_property_no_corresponding_tangent,
        parentDeclName, fieldName);
    break;
  case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_tangent_property_wrong_type,
        parentDeclName, fieldName, tanFieldInfo.error->getType());
    break;
  case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_tangent_property_not_stored,
        parentDeclName, fieldName);
    break;
  }
  return nullptr;
}

/// Convenience overload: derives the original field and the diagnostic
/// location from `projectionInst`, then forwards to the `VarDecl` overload of
/// `getTangentStoredProperty`. The assertion restricts `projectionInst` to the
/// three instruction kinds checked below.
VarDecl *getTangentStoredProperty(ADContext &context,
                                  SingleValueInstruction *projectionInst,
                                  CanType baseType,
                                  DifferentiationInvoker invoker) {
  assert(isa(projectionInst) ||
         isa(projectionInst) ||
         isa(projectionInst));
  Projection proj(projectionInst);
  auto loc = getValidLocation(projectionInst);
  // The field is resolved against the type of the instruction's first operand.
  auto *field = proj.getVarDecl(projectionInst->getOperand(0)->getType());
  return getTangentStoredProperty(context, field, baseType, loc, invoker);
}

//===----------------------------------------------------------------------===//
// Code emission utilities
//===----------------------------------------------------------------------===//

/// Returns the sole element when `elements` has exactly one entry; otherwise
/// emits a tuple of all `elements` at `loc` using `builder` and returns it.
SILValue joinElements(ArrayRef elements, SILBuilder &builder,
                      SILLocation loc) {
  if (elements.size() == 1)
    return elements.front();
  return builder.createTuple(loc, elements);
}

/// Appends the elements of `value` to `results`.
///
/// If the type lookup below yields nothing, `value` itself is appended.
/// Otherwise the elements are extracted with a single destructure instruction
/// when `builder` has ownership, or with one `tuple_extract` per element when
/// it does not. All emitted instructions use the location of `value`.
void extractAllElements(SILValue value, SILBuilder &builder,
                        SmallVectorImpl &results) {
  auto tupleType = value->getType().getAs();
  if (!tupleType) {
    results.push_back(value);
    return;
  }
  if (builder.hasOwnership()) {
    auto *dti = builder.createDestructureTuple(value.getLoc(), value);
    results.append(dti->getResults().begin(), dti->getResults().end());
    return;
  }
  for (auto i : range(tupleType->getNumElements()))
    results.push_back(builder.createTupleExtract(value.getLoc(), value, i));
}

/// Emits a call to the `Sizeof` builtin for `type` at `loc` and returns its
/// result, which has the builtin word type. The builtin is applied to a thin
/// metatype value of `type`, with a substitution map built from the builtin's
/// generic signature and `type`.
SILValue emitMemoryLayoutSize(
    SILBuilder &builder, SILLocation loc, CanType type) {
  auto &ctx = builder.getASTContext();
  auto id = ctx.getIdentifier(getBuiltinName(BuiltinValueKind::Sizeof));
  auto *builtin = cast(getBuiltinValueDecl(ctx, id));
  auto metatypeTy = SILType::getPrimitiveObjectType(
      CanMetatypeType::get(type, MetatypeRepresentation::Thin));
  auto metatypeVal = builder.createMetatype(loc, metatypeTy);
  return builder.createBuiltin(
      loc, id, SILType::getBuiltinWordType(ctx),
      SubstitutionMap::get(
          builtin->getGenericSignature(), ArrayRef{type},
          LookUpConformanceInModule()),
      {metatypeVal});
}

/// Emits a call to the `AutoDiffProjectTopLevelSubcontext` builtin on
/// `context` and converts the resulting raw pointer to a strict address of
/// `subcontextType`.
///
/// `context` must have guaranteed ownership and the native object type; both
/// are asserted.
SILValue emitProjectTopLevelSubcontext(
    SILBuilder &builder, SILLocation loc, SILValue context,
    SILType subcontextType) {
  assert(context->getOwnershipKind() == OwnershipKind::Guaranteed);
  auto &ctx = builder.getASTContext();
  auto id = ctx.getIdentifier(
      getBuiltinName(BuiltinValueKind::AutoDiffProjectTopLevelSubcontext));
  assert(context->getType() == SILType::getNativeObjectType(ctx));
  auto *subcontextAddr = builder.createBuiltin(
      loc, id, SILType::getRawPointerType(ctx), SubstitutionMap(), {context});
  return builder.createPointerToAddress(
      loc, subcontextAddr, subcontextType.getAddressType(), /*isStrict*/ true);
}

//===----------------------------------------------------------------------===//
// Utilities for looking up derivatives of functions
//===----------------------------------------------------------------------===//

/// Returns the AbstractFunctionDecl corresponding to `F`. If there isn't one,
/// returns `nullptr`.
static AbstractFunctionDecl *findAbstractFunctionDecl(SILFunction *F) {
  auto *DC = F->getDeclContext();
  if (!DC)
    return nullptr;
  auto *D = DC->getAsDecl();
  if (!D)
    return nullptr;
  return dyn_cast(D);
}

/// Returns the differentiability witness registered in `module` under the
/// name of `original` whose parameter and result index subsets are exactly
/// `parameterIndices` and `resultIndices` (compared by pointer), or `nullptr`
/// if there is no such witness.
SILDifferentiabilityWitness *
getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
                                 IndexSubset *parameterIndices,
                                 IndexSubset *resultIndices) {
  for (auto *w : module.lookUpDifferentiabilityWitnessesForFunction(
           original->getName())) {
    if (w->getParameterIndices() == parameterIndices &&
        w->getResultIndices() == resultIndices)
      return w;
  }
  return nullptr;
}

/// Searches the derivative function configurations of `original` for the one
/// whose lowered parameter indices are a superset of `parameterIndices` and
/// that has the fewest parameter indices.
///
/// On success, returns that configuration (with lowered parameter indices)
/// and stores the configuration's own, un-lowered parameter indices in
/// `minimalASTParameterIndices`. Returns `std::nullopt`, leaving
/// `minimalASTParameterIndices` untouched, when no configuration qualifies.
std::optional
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
                                   IndexSubset *parameterIndices,
                                   IndexSubset *&minimalASTParameterIndices) {
  std::optional minimalConfig = std::nullopt;
  auto configs = original->getDerivativeFunctionConfigurations();
  for (auto &config : configs) {
    auto *silParameterIndices = autodiff::getLoweredParameterIndices(
        config.parameterIndices,
        original->getInterfaceType()->castTo());

    // Widen the lowered indices so that both subsets have comparable
    // capacities before the superset test below.
    if (silParameterIndices->getCapacity() < parameterIndices->getCapacity()) {
      silParameterIndices =
          silParameterIndices->extendingCapacity(original->getASTContext(),
                                                 parameterIndices->getCapacity());
    }

    // If all indices in `parameterIndices` are in `silParameterIndices`, and
    // it has fewer indices than our current candidate, then `config` is our
    // new candidate.
    //
    // NOTE(TF-642): `config` may come from a un-partial-applied function and
    // have larger capacity than the desired indices. We expect this logic to
    // go away when `partial_apply` supports `@differentiable` callees.
    if (silParameterIndices->isSupersetOf(parameterIndices->extendingCapacity(
            original->getASTContext(), silParameterIndices->getCapacity())) &&
        // fewer parameters than before
        (!minimalConfig ||
         silParameterIndices->getNumIndices() <
             minimalConfig->parameterIndices->getNumIndices())) {
      minimalASTParameterIndices = config.parameterIndices;
      minimalConfig =
          AutoDiffConfig(silParameterIndices, config.resultIndices,
                         autodiff::getDifferentiabilityWitnessGenericSignature(
                             original->getGenericSignature(),
                             config.derivativeGenericSignature));
    }
  }
  return minimalConfig;
}

/// Returns the differentiability witness of `original` for the minimal
/// derivative configuration covering `parameterIndices`, creating a witness
/// declaration when the module does not have one yet.
///
/// Returns `nullptr` when `original` has no corresponding
/// `AbstractFunctionDecl` or when no derivative configuration covers
/// `parameterIndices`. The `resultIndices` argument is not consulted in this
/// function body; the result indices of the selected configuration are used.
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
    SILModule &module, SILFunction *original, DifferentiabilityKind kind,
    IndexSubset *parameterIndices, IndexSubset *resultIndices) {
  // Explicit differentiability witnesses only exist on SIL functions that come
  // from AST functions.
  auto *originalAFD = findAbstractFunctionDecl(original);
  if (!originalAFD)
    return nullptr;

  IndexSubset *minimalASTParameterIndices = nullptr;
  auto minimalConfig = findMinimalDerivativeConfiguration(
      originalAFD, parameterIndices, minimalASTParameterIndices);
  if (!minimalConfig)
    return nullptr;

  std::string originalName = original->getName().str();
  // If original function requires a foreign entry point, use the foreign SIL
  // function to get or create the minimal differentiability witness.
  if (requiresForeignEntryPoint(originalAFD)) {
    originalName = SILDeclRef(originalAFD).asForeign().mangle();
    // NOTE(review): `original` is dereferenced below without a null check —
    // confirm that this lookup cannot fail here.
    original = module.lookUpFunction(SILDeclRef(originalAFD).asForeign());
  }

  auto *existingWitness = module.lookUpDifferentiabilityWitness(
      {originalName, kind, *minimalConfig});
  if (existingWitness)
    return existingWitness;

  assert(original->isExternalDeclaration() &&
         "SILGen should create differentiability witnesses for all function "
         "definitions with explicit differentiable attributes");

  return SILDifferentiabilityWitness::createDeclaration(
      module,
      // Witness for @_alwaysEmitIntoClient original function must be emitted,
      // otherwise a linker error would occur due to undefined reference to the
      // witness symbol.
      original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI
                                               : SILLinkage::PublicExternal,
      original, kind, minimalConfig->parameterIndices,
      minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
}

} // end namespace autodiff
} // end namespace swift