ConstExtract: Refactor handling of AvailabilitySpec.

Soon, `AvailabilitySpec` will require that the `AvailabiltyDomain` it contains
be queried using a request that takes the `DeclContext` as input in order to
resolve the parsed domain name to an instance of `AvailabilityDomain`. The
constant extraction pipeline needed a bit of refactoring to thread a
`DeclContext` through to the place where it will be needed to execute the
query.

NFC.
This commit is contained in:
Allan Shortlidge
2025-02-17 15:08:35 -08:00
parent b1411b3cd8
commit 774248fcac
2 changed files with 90 additions and 55 deletions

View File

@@ -217,12 +217,24 @@ public:
///
class ConditionalMember : public BuilderMember {
public:
class AvailabilitySpec {
private:
AvailabilityDomain Domain;
llvm::VersionTuple Version;
public:
AvailabilitySpec(AvailabilityDomain Domain, llvm::VersionTuple Version)
: Domain(Domain), Version(Version) {}
AvailabilityDomain getDomain() const { return Domain; }
llvm::VersionTuple getVersion() const { return Version; }
};
ConditionalMember(MemberKind MemberKind,
std::vector<AvailabilitySpec> AvailabilityAttributes,
std::vector<AvailabilitySpec> AvailabilitySpecs,
std::vector<std::shared_ptr<BuilderMember>> IfElements,
std::vector<std::shared_ptr<BuilderMember>> ElseElements)
: BuilderMember(MemberKind),
AvailabilityAttributes(AvailabilityAttributes),
: BuilderMember(MemberKind), AvailabilitySpecs(AvailabilitySpecs),
IfElements(IfElements), ElseElements(ElseElements) {}
ConditionalMember(MemberKind MemberKind,
@@ -238,9 +250,8 @@ public:
(Kind == MemberKind::Optional);
}
std::optional<std::vector<AvailabilitySpec>>
getAvailabilityAttributes() const {
return AvailabilityAttributes;
std::optional<std::vector<AvailabilitySpec>> getAvailabilitySpecs() const {
return AvailabilitySpecs;
}
std::vector<std::shared_ptr<BuilderMember>> getIfElements() const {
return IfElements;
@@ -250,7 +261,7 @@ public:
}
private:
std::optional<std::vector<AvailabilitySpec>> AvailabilityAttributes;
std::optional<std::vector<AvailabilitySpec>> AvailabilitySpecs;
std::vector<std::shared_ptr<BuilderMember>> IfElements;
std::vector<std::shared_ptr<BuilderMember>> ElseElements;
};

View File

@@ -168,12 +168,15 @@ parseProtocolListFromFile(StringRef protocolListFilePath,
}
std::vector<std::shared_ptr<BuilderValue::BuilderMember>>
getResultBuilderMembersFromBraceStmt(BraceStmt *braceStmt);
getResultBuilderMembersFromBraceStmt(BraceStmt *braceStmt,
const DeclContext *declContext);
static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr);
static std::shared_ptr<CompileTimeValue>
extractCompileTimeValue(Expr *expr, const DeclContext *declContext);
static std::vector<FunctionParameter>
extractFunctionArguments(const ArgumentList *args) {
extractFunctionArguments(const ArgumentList *args,
const DeclContext *declContext) {
std::vector<FunctionParameter> parameters;
for (auto arg : *args) {
@@ -188,7 +191,8 @@ extractFunctionArguments(const ArgumentList *args) {
} else if (auto optionalInject = dyn_cast<InjectIntoOptionalExpr>(argExpr)) {
argExpr = optionalInject->getSubExpr();
}
parameters.push_back({label, type, extractCompileTimeValue(argExpr)});
parameters.push_back(
{label, type, extractCompileTimeValue(argExpr, declContext)});
}
return parameters;
@@ -224,7 +228,8 @@ static std::optional<std::string> extractRawLiteral(Expr *expr) {
return std::nullopt;
}
static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
static std::shared_ptr<CompileTimeValue>
extractCompileTimeValue(Expr *expr, const DeclContext *declContext) {
if (expr) {
switch (expr->getKind()) {
case ExprKind::BooleanLiteral:
@@ -247,7 +252,8 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
auto arrayExpr = cast<ArrayExpr>(expr);
std::vector<std::shared_ptr<CompileTimeValue>> elementValues;
for (const auto elementExpr : arrayExpr->getElements()) {
elementValues.push_back(extractCompileTimeValue(elementExpr));
elementValues.push_back(
extractCompileTimeValue(elementExpr, declContext));
}
return std::make_shared<ArrayValue>(elementValues);
}
@@ -256,7 +262,7 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
auto dictionaryExpr = cast<DictionaryExpr>(expr);
std::vector<std::shared_ptr<TupleValue>> tuples;
for (auto elementExpr : dictionaryExpr->getElements()) {
auto elementValue = extractCompileTimeValue(elementExpr);
auto elementValue = extractCompileTimeValue(elementExpr, declContext);
if (isa<TupleValue>(elementValue.get())) {
tuples.push_back(std::static_pointer_cast<TupleValue>(elementValue));
}
@@ -279,13 +285,15 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
? std::nullopt
: std::optional<std::string>(elementName.str().str());
elements.push_back({label, elementExpr->getType(),
extractCompileTimeValue(elementExpr)});
elements.push_back(
{label, elementExpr->getType(),
extractCompileTimeValue(elementExpr, declContext)});
}
} else {
for (auto elementExpr : tupleExpr->getElements()) {
elements.push_back({std::nullopt, elementExpr->getType(),
extractCompileTimeValue(elementExpr)});
elements.push_back(
{std::nullopt, elementExpr->getType(),
extractCompileTimeValue(elementExpr, declContext)});
}
}
return std::make_shared<TupleValue>(elements);
@@ -301,13 +309,13 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
declRefExpr->getDecl()->getName().getBaseIdentifier().str().str();
std::vector<FunctionParameter> parameters =
extractFunctionArguments(callExpr->getArgs());
extractFunctionArguments(callExpr->getArgs(), declContext);
return std::make_shared<FunctionCallValue>(identifier, parameters);
}
if (functionKind == ExprKind::ConstructorRefCall) {
std::vector<FunctionParameter> parameters =
extractFunctionArguments(callExpr->getArgs());
extractFunctionArguments(callExpr->getArgs(), declContext);
return std::make_shared<InitCallValue>(callExpr->getType(), parameters);
}
@@ -320,7 +328,7 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
declRefExpr->getDecl()->getName().getBaseIdentifier().str().str();
std::vector<FunctionParameter> parameters =
extractFunctionArguments(callExpr->getArgs());
extractFunctionArguments(callExpr->getArgs(), declContext);
auto declRef = dotSyntaxCallExpr->getFn()->getReferencedDecl();
switch (declRef.getDecl()->getKind()) {
@@ -364,23 +372,23 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
case ExprKind::Erasure: {
auto erasureExpr = cast<ErasureExpr>(expr);
return extractCompileTimeValue(erasureExpr->getSubExpr());
return extractCompileTimeValue(erasureExpr->getSubExpr(), declContext);
}
case ExprKind::Paren: {
auto parenExpr = cast<ParenExpr>(expr);
return extractCompileTimeValue(parenExpr->getSubExpr());
return extractCompileTimeValue(parenExpr->getSubExpr(), declContext);
}
case ExprKind::PropertyWrapperValuePlaceholder: {
auto placeholderExpr = cast<PropertyWrapperValuePlaceholderExpr>(expr);
return extractCompileTimeValue(
placeholderExpr->getOriginalWrappedValue());
return extractCompileTimeValue(placeholderExpr->getOriginalWrappedValue(),
declContext);
}
case ExprKind::Coerce: {
auto coerceExpr = cast<CoerceExpr>(expr);
return extractCompileTimeValue(coerceExpr->getSubExpr());
return extractCompileTimeValue(coerceExpr->getSubExpr(), declContext);
}
case ExprKind::DotSelf: {
@@ -394,7 +402,8 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
case ExprKind::UnderlyingToOpaque: {
auto underlyingToOpaque = cast<UnderlyingToOpaqueExpr>(expr);
return extractCompileTimeValue(underlyingToOpaque->getSubExpr());
return extractCompileTimeValue(underlyingToOpaque->getSubExpr(),
declContext);
}
case ExprKind::DefaultArgument: {
@@ -445,12 +454,13 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
case ExprKind::InjectIntoOptional: {
auto injectIntoOptionalExpr = cast<InjectIntoOptionalExpr>(expr);
return extractCompileTimeValue(injectIntoOptionalExpr->getSubExpr());
return extractCompileTimeValue(injectIntoOptionalExpr->getSubExpr(),
declContext);
}
case ExprKind::Load: {
auto loadExpr = cast<LoadExpr>(expr);
return extractCompileTimeValue(loadExpr->getSubExpr());
return extractCompileTimeValue(loadExpr->getSubExpr(), declContext);
}
case ExprKind::MemberRef: {
@@ -474,7 +484,7 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
Ctx, [&](bool isInterpolation, CallExpr *segment) -> void {
auto arg = segment->getArgs()->get(0);
auto expr = arg.getExpr();
segments.push_back(extractCompileTimeValue(expr));
segments.push_back(extractCompileTimeValue(expr, declContext));
});
return std::make_shared<InterpolatedStringLiteralValue>(segments);
@@ -483,7 +493,8 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
case ExprKind::Closure: {
auto closureExpr = cast<ClosureExpr>(expr);
auto body = closureExpr->getBody();
auto resultBuilderMembers = getResultBuilderMembersFromBraceStmt(body);
auto resultBuilderMembers =
getResultBuilderMembersFromBraceStmt(body, declContext);
if (!resultBuilderMembers.empty()) {
return std::make_shared<BuilderValue>(resultBuilderMembers);
@@ -493,7 +504,7 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
case ExprKind::DerivedToBase: {
auto derivedExpr = cast<DerivedToBaseExpr>(expr);
return extractCompileTimeValue(derivedExpr->getSubExpr());
return extractCompileTimeValue(derivedExpr->getSubExpr(), declContext);
}
default: {
break;
@@ -504,8 +515,8 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
return std::make_shared<RuntimeValue>();
}
static CustomAttrValue
extractAttributeValue(const CustomAttr *attr) {
static CustomAttrValue extractAttributeValue(const CustomAttr *attr,
const DeclContext *declContext) {
std::vector<FunctionParameter> parameters;
if (const auto *args = attr->getArgs()) {
for (auto arg : *args) {
@@ -518,8 +529,8 @@ extractAttributeValue(const CustomAttr *attr) {
argExpr = decl->getTypeCheckedDefaultExpr();
}
}
parameters.push_back(
{label, argExpr->getType(), extractCompileTimeValue(argExpr)});
parameters.push_back({label, argExpr->getType(),
extractCompileTimeValue(argExpr, declContext)});
}
}
return {attr, parameters};
@@ -529,7 +540,8 @@ static AttrValueVector
extractPropertyWrapperAttrValues(VarDecl *propertyDecl) {
AttrValueVector customAttrValues;
for (auto *propertyWrapper : propertyDecl->getAttachedPropertyWrappers())
customAttrValues.push_back(extractAttributeValue(propertyWrapper));
customAttrValues.push_back(
extractAttributeValue(propertyWrapper, propertyDecl->getDeclContext()));
return customAttrValues;
}
@@ -541,7 +553,9 @@ extractTypePropertyInfo(VarDecl *propertyDecl) {
if (const auto binding = propertyDecl->getParentPatternBinding()) {
if (const auto originalInit = binding->getInit(0)) {
return {propertyDecl, extractCompileTimeValue(originalInit),
return {propertyDecl,
extractCompileTimeValue(originalInit,
propertyDecl->getInnermostDeclContext()),
propertyWrapperValues};
}
}
@@ -551,9 +565,11 @@ extractTypePropertyInfo(VarDecl *propertyDecl) {
auto node = body->getFirstElement();
if (auto *stmt = node.dyn_cast<Stmt *>()) {
if (stmt->getKind() == StmtKind::Return) {
return {propertyDecl,
extractCompileTimeValue(cast<ReturnStmt>(stmt)->getResult()),
propertyWrapperValues};
return {
propertyDecl,
extractCompileTimeValue(cast<ReturnStmt>(stmt)->getResult(),
accessorDecl->getInnermostDeclContext()),
propertyWrapperValues};
}
}
}
@@ -992,7 +1008,8 @@ getResultBuilderElementFromASTNode(const ASTNode node) {
if (auto *D = node.dyn_cast<Decl *>()) {
if (auto *patternBinding = dyn_cast<PatternBindingDecl>(D)) {
if (auto originalInit = patternBinding->getOriginalInit(0)) {
return extractCompileTimeValue(originalInit);
return extractCompileTimeValue(
originalInit, patternBinding->getInnermostDeclContext());
}
}
}
@@ -1000,8 +1017,10 @@ getResultBuilderElementFromASTNode(const ASTNode node) {
}
BuilderValue::ConditionalMember
getConditionalMemberFromIfStmt(const IfStmt *ifStmt) {
std::vector<AvailabilitySpec> AvailabilityAttributes;
getConditionalMemberFromIfStmt(const IfStmt *ifStmt,
const DeclContext *declContext) {
std::vector<BuilderValue::ConditionalMember::AvailabilitySpec>
AvailabilitySpecs;
std::vector<std::shared_ptr<BuilderValue::BuilderMember>> IfElements;
std::vector<std::shared_ptr<BuilderValue::BuilderMember>> ElseElements;
if (auto thenBraceStmt = ifStmt->getThenStmt()) {
@@ -1016,7 +1035,7 @@ getConditionalMemberFromIfStmt(const IfStmt *ifStmt) {
if (auto elseStmt = ifStmt->getElseStmt()) {
if (auto *elseIfStmt = dyn_cast<IfStmt>(elseStmt)) {
ElseElements.push_back(std::make_shared<BuilderValue::ConditionalMember>(
getConditionalMemberFromIfStmt(elseIfStmt)));
getConditionalMemberFromIfStmt(elseIfStmt, declContext)));
} else if (auto *elseBraceStmt = dyn_cast<BraceStmt>(elseStmt)) {
for (auto elem : elseBraceStmt->getElements()) {
if (auto memberElement = getResultBuilderElementFromASTNode(elem)) {
@@ -1035,7 +1054,9 @@ getConditionalMemberFromIfStmt(const IfStmt *ifStmt) {
if (elt.getKind() == StmtConditionElement::CK_Availability) {
for (auto *Q : elt.getAvailability()->getQueries()) {
if (Q->getPlatform() != PlatformKind::none) {
AvailabilityAttributes.push_back(*Q);
auto spec = BuilderValue::ConditionalMember::AvailabilitySpec(
*Q->getDomain(), Q->getVersion());
AvailabilitySpecs.push_back(spec);
}
}
memberKind = BuilderValue::LimitedAvailability;
@@ -1043,12 +1064,12 @@ getConditionalMemberFromIfStmt(const IfStmt *ifStmt) {
}
}
if (AvailabilityAttributes.empty()) {
if (AvailabilitySpecs.empty()) {
return BuilderValue::ConditionalMember(memberKind, IfElements,
ElseElements);
}
return BuilderValue::ConditionalMember(memberKind, AvailabilityAttributes,
return BuilderValue::ConditionalMember(memberKind, AvailabilitySpecs,
IfElements, ElseElements);
}
@@ -1067,7 +1088,8 @@ getBuildArrayMemberFromForEachStmt(const ForEachStmt *forEachStmt) {
}
std::vector<std::shared_ptr<BuilderValue::BuilderMember>>
getResultBuilderMembersFromBraceStmt(BraceStmt *braceStmt) {
getResultBuilderMembersFromBraceStmt(BraceStmt *braceStmt,
const DeclContext *declContext) {
std::vector<std::shared_ptr<BuilderValue::BuilderMember>>
ResultBuilderMembers;
for (auto elem : braceStmt->getElements()) {
@@ -1079,7 +1101,7 @@ getResultBuilderMembersFromBraceStmt(BraceStmt *braceStmt) {
if (auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
ResultBuilderMembers.push_back(
std::make_shared<BuilderValue::ConditionalMember>(
getConditionalMemberFromIfStmt(ifStmt)));
getConditionalMemberFromIfStmt(ifStmt, declContext)));
} else if (auto *doStmt = dyn_cast<DoStmt>(stmt)) {
if (auto body = doStmt->getBody()) {
for (auto elem : body->getElements()) {
@@ -1106,7 +1128,8 @@ createBuilderCompileTimeValue(CustomAttr *AttachedResultBuilder,
if (!VarDecl->getAllAccessors().empty()) {
if (auto accessor = VarDecl->getAllAccessors()[0]) {
if (auto braceStmt = accessor->getTypecheckedBody()) {
ResultBuilderMembers = getResultBuilderMembersFromBraceStmt(braceStmt);
ResultBuilderMembers = getResultBuilderMembersFromBraceStmt(
braceStmt, accessor->getDeclContext());
}
}
}
@@ -1159,12 +1182,13 @@ void writeBuilderMember(
default: {
auto member = cast<BuilderValue::ConditionalMember>(Member);
if (auto availabilityAttributes = member->getAvailabilityAttributes()) {
if (auto availabilitySpecs = member->getAvailabilitySpecs()) {
JSON.attributeArray("availabilityAttributes", [&] {
for (auto elem : *availabilityAttributes) {
for (auto elem : *availabilitySpecs) {
JSON.object([&] {
JSON.attribute("platform",
platformString(elem.getPlatform()).str());
JSON.attribute(
"platform",
platformString(elem.getDomain().getPlatformKind()).str());
JSON.attribute("minVersion", elem.getVersion().getAsString());
});
}