diff --git a/include/swift/SILOptimizer/PassManager/Passes.def b/include/swift/SILOptimizer/PassManager/Passes.def index 83db39ea944..55f12b2d7d8 100644 --- a/include/swift/SILOptimizer/PassManager/Passes.def +++ b/include/swift/SILOptimizer/PassManager/Passes.def @@ -152,6 +152,9 @@ PASS(DiagnoseUnreachable, "diagnose-unreachable", "Diagnose Unreachable Code") PASS(DiagnosticConstantPropagation, "diagnostic-constant-propagation", "Constants Propagation for Diagnostics") +PASS(DifferentiabilityWitnessDevirtualizer, + "differentiability-witness-devirtualizer", + "Inlines Differentiability Witnesses") PASS(EagerSpecializer, "eager-specializer", "Eager Specialization via @_specialize") PASS(EarlyCodeMotion, "early-codemotion", diff --git a/lib/SILOptimizer/PassManager/PassPipeline.cpp b/lib/SILOptimizer/PassManager/PassPipeline.cpp index f2d717f4f4f..926d36aebc5 100644 --- a/lib/SILOptimizer/PassManager/PassPipeline.cpp +++ b/lib/SILOptimizer/PassManager/PassPipeline.cpp @@ -408,6 +408,11 @@ static void addPerfEarlyModulePassPipeline(SILPassPipelinePlan &P) { // Cleanup after SILGen: remove unneeded borrows/copies. P.addSemanticARCOpts(); + // Devirtualizes differentiability witnesses into functions that reference them. + // This unblocks many other passes' optimizations (e.g. inlining) and this is + // not blocked by any other passes' optimizations, so do it early. + P.addDifferentiabilityWitnessDevirtualizer(); + // Strip ownership from non-transparent functions. if (P.getOptions().StripOwnershipAfterSerialization) P.addNonTransparentFunctionOwnershipModelEliminator(); diff --git a/lib/SILOptimizer/Transforms/CMakeLists.txt b/lib/SILOptimizer/Transforms/CMakeLists.txt index 484d4280bf4..e060c0208c1 100644 --- a/lib/SILOptimizer/Transforms/CMakeLists.txt +++ b/lib/SILOptimizer/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ silopt_register_sources( DeadStoreElimination.cpp DestroyHoisting.cpp Devirtualizer.cpp + DifferentiabilityWitnessDevirtualizer.cpp GenericSpecializer.cpp MergeCondFail.cpp Outliner.cpp diff --git a/lib/SILOptimizer/Transforms/DifferentiabilityWitnessDevirtualizer.cpp b/lib/SILOptimizer/Transforms/DifferentiabilityWitnessDevirtualizer.cpp new file mode 100644 index 00000000000..0ffdc75826a --- /dev/null +++ b/lib/SILOptimizer/Transforms/DifferentiabilityWitnessDevirtualizer.cpp @@ -0,0 +1,71 @@ +//===--- DifferentiabilityWitnessDevirtualizer.cpp ------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// Devirtualizes `differentiability_witness_function` instructions into +// `function_ref` instructions for differentiability witness definitions. +// +//===----------------------------------------------------------------------===// + +#include "swift/SIL/SILBuilder.h" +#include "swift/SIL/SILFunction.h" +#include "swift/SIL/SILInstruction.h" +#include "swift/SILOptimizer/PassManager/Transforms.h" + +using namespace swift; + +namespace { +class DifferentiabilityWitnessDevirtualizer : public SILFunctionTransform { + + /// Returns true if any changes were made. + bool devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f); + + /// The entry point to the transformation. + void run() override { + if (devirtualizeDifferentiabilityWitnessesInFunction(*getFunction())) + invalidateAnalysis(SILAnalysis::InvalidationKind::CallsAndInstructions); + } +}; +} // end anonymous namespace + +bool DifferentiabilityWitnessDevirtualizer:: + devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f) { + bool changed = false; + llvm::SmallVector insts; + for (auto &bb : f) { + for (auto &inst : bb) { + auto *dfwi = dyn_cast(&inst); + if (!dfwi) + continue; + insts.push_back(dfwi); + } + } + for (auto *inst : insts) { + auto *witness = inst->getWitness(); + if (witness->isDeclaration()) + f.getModule().loadDifferentiabilityWitness(witness); + if (witness->isDeclaration()) + continue; + changed = true; + SILBuilderWithScope builder(inst); + auto kind = inst->getWitnessKind().getAsDerivativeFunctionKind(); + assert(kind.hasValue()); + auto *newInst = builder.createFunctionRefFor(inst->getLoc(), + witness->getDerivative(*kind)); + inst->replaceAllUsesWith(newInst); + inst->getParent()->erase(inst); + } + return changed; +} + +SILTransform *swift::createDifferentiabilityWitnessDevirtualizer() { + return new DifferentiabilityWitnessDevirtualizer(); +} diff --git a/test/AutoDiff/SILOptimizer/differentiability_witness_inlining.sil b/test/AutoDiff/SILOptimizer/differentiability_witness_inlining.sil new file mode 100644 index 00000000000..3637309bf21 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/differentiability_witness_inlining.sil @@ -0,0 +1,43 @@ +// RUN: %target-sil-opt -differentiability-witness-devirtualizer %s -enable-sil-verify-all | %FileCheck %s + +sil_stage raw + +import _Differentiation +import Swift +import Builtin + +sil_differentiability_witness [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float { + jvp: @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + vjp: @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +} + +sil_differentiability_witness [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float + +// This is an example of a witness that is available (via deserialization) +// even though it is not defined in the current module. +// witness for static Swift.Float.+ infix(Swift.Float, Swift.Float) -> Swift.Float +sil_differentiability_witness [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float + +sil @witness_defined_in_module : $@convention(thin) (Float) -> Float + +sil @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + +sil @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + +sil @witness_definition_not_available : $@convention(thin) (Float) -> Float + +sil public_external [transparent] [serialized] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float + +sil @test : $@convention(thin) (Float) -> () { +bb0(%0 : $Float): + %1 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float + // CHECK: %1 = function_ref @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + + %2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float + // CHECK: %2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float + + %3 = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float + // CHECK: %3 = function_ref @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) + + return undef : $() +}