mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
203 lines
8.6 KiB
C++
203 lines
8.6 KiB
C++
//===--- TangentBuilder.cpp - Tangent SIL builder ------------*- C++ -*----===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines a helper class for emitting tangent code for automatic
|
|
// differentiation.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define DEBUG_TYPE "differentiation"
|
|
|
|
#include "swift/AST/ConformanceLookup.h"
|
|
#include "swift/Basic/Assertions.h"
|
|
#include "swift/SILOptimizer/Differentiation/TangentBuilder.h"
|
|
#include "swift/SILOptimizer/Differentiation/ADContext.h"
|
|
|
|
namespace swift {
|
|
namespace autodiff {
|
|
|
|
/// Emit code that writes the `AdditiveArithmetic` zero value into `buffer`.
///
/// \param loc    location attached to every emitted instruction.
/// \param buffer destination address. A tuple-typed buffer is zeroed one
///               element address at a time.
/// \param isInit when false, `emitDestroyAddr` is called on `buffer` before
///               the zero value is written.
void TangentBuilder::emitZeroIntoBuffer(SILLocation loc, SILValue buffer,
                                        IsInitialization_t isInit) {
  // Not an initialization: destroy the buffer before writing into it.
  if (!isInit) {
    emitDestroyAddr(loc, buffer);
  }

  auto bufferTy = buffer->getType();

  // Tuple case: recurse into each element address. The recursive calls always
  // pass `IsInitialization`, since any required destroy was emitted above.
  if (auto tupleTy = bufferTy.getAs<TupleType>()) {
    for (unsigned idx = 0, numElts = tupleTy->getNumElements(); idx != numElts;
         ++idx) {
      auto *elementAddr = createTupleElementAddr(loc, buffer, idx);
      emitZeroIntoBuffer(loc, elementAddr, IsInitialization);
    }
    return;
  }

  // Non-tuple case: find the `AdditiveArithmetic` conformance of the buffer's
  // AST type and apply the zero getter witness.
  auto *proto = adContext.getAdditiveArithmeticProtocol();
  auto selfType = bufferTy.getASTType();
  auto conformance = lookupConformance(selfType, proto);
  assert(!conformance.isInvalid() && "Missing conformance to `AdditiveArithmetic`");

  SILDeclRef zeroGetterRef(adContext.getAdditiveArithmeticZeroGetter(),
                           SILDeclRef::Kind::Func);
  auto zeroGetterTy = getModule().Types.getConstantType(
      getTypeExpansionContext(), zeroGetterRef);

  // %wm = witness_method ...
  auto *zeroGetter = createWitnessMethod(loc, selfType, conformance,
                                         zeroGetterRef, zeroGetterTy);

  // %metatype = metatype $T
  auto thickMetatypeTy =
      CanMetatypeType::get(selfType, MetatypeRepresentation::Thick);
  auto *selfMetatype =
      createMetatype(loc, SILType::getPrimitiveObjectType(thickMetatypeTy));

  // Apply the getter with `buffer` and the metatype as its arguments, then
  // emit a destroy of the witness method value.
  auto substitutions =
      SubstitutionMap::getProtocolSubstitutions(proto, selfType, conformance);
  createApply(loc, zeroGetter, substitutions, {buffer, selfMetatype});
  emitDestroyValueOperation(loc, zeroGetter);
}
|
|
|
|
/// Emit code that produces the zero value of `type` as an object value.
///
/// The value is built in a temporary stack allocation via
/// `emitZeroIntoBuffer`, loaded out of it with `Take`, and the allocation is
/// then deallocated.
///
/// \param loc  location used for the zero-initialization, the load and the
///             deallocation (the allocation itself uses an auto-generated
///             location).
/// \param type the AST type whose zero value is requested.
/// \returns the value loaded from the temporary buffer.
SILValue TangentBuilder::emitZero(SILLocation loc, CanType type) {
  auto &silModule = getModule();
  auto loweredTy = silModule.Types.getLoweredLoadableType(
      type, TypeExpansionContext::minimal(), silModule);

  // Temporary storage that receives the zero value.
  auto autoGenLoc = RegularLocation::getAutoGeneratedLocation();
  auto *tempBuffer = createAllocStack(autoGenLoc, loweredTy);
  emitZeroIntoBuffer(loc, tempBuffer, IsInitialization);

  // Take the value out of the buffer before releasing the stack slot.
  auto result =
      emitLoadValueOperation(loc, tempBuffer, LoadOwnershipQualifier::Take);
  createDeallocStack(loc, tempBuffer);
  return result;
}
|
|
|
|
/// Emit code that accumulates `operand` into `destinationBuffer` using the
/// `+=` witness of `AdditiveArithmetic`.
///
/// \param loc               location attached to every emitted instruction.
/// \param destinationBuffer address that is accumulated into; asserted to be
///                          of address type.
/// \param operand           the value to add. For tuple-typed destinations it
///                          may be an address or an object and is projected
///                          element by element accordingly.
void TangentBuilder::emitInPlaceAdd(
    SILLocation loc, SILValue destinationBuffer, SILValue operand) {
  assert(destinationBuffer->getType().isAddress());
  auto destTy = destinationBuffer->getType();

  // Tuple case: accumulate element-wise.
  if (auto tupleTy = destTy.getAs<TupleType>()) {
    // The operand's value category is the same for every element.
    bool operandIsAddress = operand->getType().isAddress();
    for (unsigned idx = 0, numElts = tupleTy->getNumElements(); idx != numElts;
         ++idx) {
      auto *destElementAddr =
          createTupleElementAddr(loc, destinationBuffer, idx);
      if (operandIsAddress) {
        // Address operand: project the matching element address.
        auto *operandElementAddr = createTupleElementAddr(loc, operand, idx);
        emitInPlaceAdd(loc, destElementAddr, operandElementAddr);
      } else {
        // Object operand: extract the element from a borrow that spans the
        // recursive accumulation.
        auto borrowedOperand = emitBeginBorrowOperation(loc, operand);
        auto operandElement = emitTupleExtract(loc, borrowedOperand, idx);
        emitInPlaceAdd(loc, destElementAddr, operandElement);
        emitEndBorrowOperation(loc, borrowedOperand);
      }
    }
    return;
  }

  // Non-tuple case: call the combiner function and return.
  auto *proto = adContext.getAdditiveArithmeticProtocol();
  auto selfType = destTy.getASTType();
  auto conformance = lookupConformance(selfType, proto);
  assert(!conformance.isInvalid() &&
         "Missing conformance to `AdditiveArithmetic`");

  SILDeclRef plusEqualRef(adContext.getPlusEqualDecl(), SILDeclRef::Kind::Func);
  auto plusEqualTy = getModule().Types.getConstantType(
      getTypeExpansionContext(), plusEqualRef);

  // %0 = witness_method @+=
  auto plusEqualWitness = createWitnessMethod(loc, selfType, conformance,
                                              plusEqualRef, plusEqualTy);
  auto substitutions = SubstitutionMap::getProtocolSubstitutions(
      adContext.getAdditiveArithmeticProtocol(), selfType, conformance);

  // %1 = metatype $T.Type
  auto thickMetatypeTy =
      CanMetatypeType::get(selfType, MetatypeRepresentation::Thick);
  auto selfMetatype =
      createMetatype(loc, SILType::getPrimitiveObjectType(thickMetatypeTy));

  // %2 = apply $0(%lhs, %rhs, %1)
  createApply(loc, plusEqualWitness, substitutions,
              {destinationBuffer, operand, selfMetatype});
  emitDestroyValueOperation(loc, plusEqualWitness);
}
|
|
|
|
/// Emit code that adds the values at `lhsAddress` and `rhsAddress` using the
/// `+` witness of `AdditiveArithmetic`, with `destinationBuffer` as the
/// result argument.
///
/// \param loc               location attached to every emitted instruction.
/// \param destinationBuffer address passed as the result of the addition.
/// \param lhsAddress        address of the first addend.
/// \param rhsAddress        address of the second addend; asserted to have
///                          the same AST type as `lhsAddress`.
///
/// Tuple-typed operands are added element-wise. Note that the witness is
/// applied with `rhsAddress` ahead of `lhsAddress`.
void TangentBuilder::emitAddIntoBuffer(SILLocation loc,
                                       SILValue destinationBuffer,
                                       SILValue lhsAddress,
                                       SILValue rhsAddress) {
  assert(lhsAddress->getType().getASTType() ==
             rhsAddress->getType().getASTType() &&
         "Adjoint values must have same type!");
  assert(lhsAddress->getType().isAddress() &&
         rhsAddress->getType().isAddress() &&
         "Adjoint values must both have address types!");

  auto operandTy = lhsAddress->getType();

  // Tuple case: project the same element out of all three buffers and recurse.
  if (auto tupleTy = operandTy.getAs<TupleType>()) {
    for (unsigned idx = 0, numElts = tupleTy->getNumElements(); idx != numElts;
         ++idx) {
      auto *destElementAddr =
          createTupleElementAddr(loc, destinationBuffer, idx);
      auto *lhsElementAddr = createTupleElementAddr(loc, lhsAddress, idx);
      auto *rhsElementAddr = createTupleElementAddr(loc, rhsAddress, idx);
      emitAddIntoBuffer(loc, destElementAddr, lhsElementAddr, rhsElementAddr);
    }
    return;
  }

  // Non-tuple case: call the combiner function and return.
  auto selfType = operandTy.getASTType();
  auto *additiveProto = adContext.getAdditiveArithmeticProtocol();
  auto *plusDecl = adContext.getPlusDecl();

  auto conformance = lookupConformance(selfType, additiveProto);
  assert(!conformance.isInvalid() &&
         "Missing conformance to `AdditiveArithmetic`");

  SILDeclRef plusRef(plusDecl, SILDeclRef::Kind::Func);
  auto plusTy =
      getModule().Types.getConstantType(getTypeExpansionContext(), plusRef);

  // %0 = witness_method @+
  auto plusWitness =
      createWitnessMethod(loc, selfType, conformance, plusRef, plusTy);
  auto substitutions = SubstitutionMap::getProtocolSubstitutions(
      additiveProto, selfType, conformance);

  // %1 = metatype $T.Type
  auto thickMetatypeTy =
      CanMetatypeType::get(selfType, MetatypeRepresentation::Thick);
  auto selfMetatype =
      createMetatype(loc, SILType::getPrimitiveObjectType(thickMetatypeTy));

  // %2 = apply %0(%result, %new, %old, %1)
  createApply(loc, plusWitness, substitutions,
              {destinationBuffer, rhsAddress, lhsAddress, selfMetatype});
  emitDestroyValueOperation(loc, plusWitness);
}
|
|
|
|
/// Emit code that adds two object values and returns their sum.
///
/// Both operands are copied and stored into temporary stack buffers, the sum
/// is written into a third buffer by `emitAddIntoBuffer`, and the result is
/// loaded out with `Take`.
///
/// \param loc location attached to the emitted instructions (the stack
///            allocations use an auto-generated location).
/// \param lhs first addend; asserted to be of object type.
/// \param rhs second addend; asserted to have exactly the type of `lhs`.
/// \returns the value loaded from the result buffer.
SILValue TangentBuilder::emitAdd(SILLocation loc, SILValue lhs, SILValue rhs) {
  LLVM_DEBUG(getADDebugStream() << "Emitting adjoint accumulation for lhs: "
                                << lhs << " and rhs: " << rhs);
  assert(lhs->getType() == rhs->getType() && "Adjoints must have equal types!");
  assert(lhs->getType().isObject() && rhs->getType().isObject() &&
         "Adjoint types must be both object types!");

  auto valueTy = lhs->getType();

  // Copy the operands so the input buffers own what they are initialized with.
  auto lhsOwned = emitCopyValueOperation(loc, lhs);
  auto rhsOwned = emitCopyValueOperation(loc, rhs);

  // Allocate buffers for inputs and output.
  auto autoGenLoc = RegularLocation::getAutoGeneratedLocation();
  auto *sumBuffer = createAllocStack(autoGenLoc, valueTy);
  auto *lhsBuffer = createAllocStack(autoGenLoc, valueTy);
  auto *rhsBuffer = createAllocStack(autoGenLoc, valueTy);

  // Initialize input buffers.
  emitStoreValueOperation(loc, lhsOwned, lhsBuffer,
                          StoreOwnershipQualifier::Init);
  emitStoreValueOperation(loc, rhsOwned, rhsBuffer,
                          StoreOwnershipQualifier::Init);

  // Compute the sum into the result buffer.
  emitAddIntoBuffer(loc, sumBuffer, lhsBuffer, rhsBuffer);

  // Destroy the input copies, then deallocate the input buffers in the
  // reverse of their allocation order.
  emitDestroyAddr(loc, lhsBuffer);
  emitDestroyAddr(loc, rhsBuffer);
  createDeallocStack(loc, rhsBuffer);
  createDeallocStack(loc, lhsBuffer);

  // Take the sum out of the result buffer, then deallocate it.
  auto sum =
      emitLoadValueOperation(loc, sumBuffer, LoadOwnershipQualifier::Take);
  createDeallocStack(loc, sumBuffer);
  return sum;
}
|
|
|
|
} // end namespace autodiff
|
|
} // end namespace swift
|