Optimization: Detect variable references discarded sooner

* The expression was already adding its usage indication,
  undo that such that next micro pass knows if it's used
  or not sooner. Otherwise it may not happen if that will
  have no other changes. And if it did, it will need even
  one more than necessary.
This commit is contained in:
Kay Hayen
2025-11-28 12:22:31 +00:00
parent 1fe435085b
commit 6ac5c8b571
7 changed files with 1943 additions and 2 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -716,6 +716,7 @@ branches.""",
Both branches have no effect, reduced to evaluate condition.""",
)
else:
condition.undoComputeExpressionRaw(trace_collection)
self.finalize()
return (

View File

@@ -153,7 +153,7 @@ class ExpressionBase(NodeBase):
@staticmethod
def extractUnhashableNodeType():
"""Return the value type that is not hashable, if isKnowtoBeHashable() returns False."""
"""Return the value type that is not hashable, if isKnownToBeHashable() returns False."""
# Not available by default.
return None
@@ -182,6 +182,11 @@ class ExpressionBase(NodeBase):
not overload this unless necessary.
"""
@staticmethod
def undoComputeExpressionRaw(trace_collection):
# Virtual method
pass
def computeExpressionAttribute(self, lookup_node, attribute_name, trace_collection):
# By default, an attribute lookup may change everything about the lookup
# source.

View File

@@ -34,7 +34,6 @@ spell-checker: ignore winmode zfill
# Loop unrolling over child names, pylint: disable=too-many-branches
from abc import abstractmethod
from .ExpressionBases import ExpressionBase
@@ -62,6 +61,17 @@ class NoChildHavingFinalNoRaiseMixin(ExpressionBase):
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -141,6 +151,17 @@ class NoChildHavingFinalNoRaiseNameMixin(ExpressionBase):
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -259,6 +280,17 @@ class ChildHavingArgsTupleFinalNoRaiseMixin(ExpressionBase):
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -442,6 +474,17 @@ class ChildrenHavingArgsTupleNameOptionalObjOptionalFinalNoRaiseForRaiseMixin(
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -644,6 +687,17 @@ class ChildrenHavingArgsTupleNameOptionalPathOptionalFinalNoRaiseForRaiseMixin(
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -796,6 +850,17 @@ class ChildrenHavingCallableArgSentinelFinalMixin(ExpressionBase):
trace_collection.onExceptionRaiseExit(BaseException)
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
def collectVariableAccesses(self, emit_variable):
"""Collect variable reads and writes of child nodes."""
@@ -877,6 +942,17 @@ class ChildHavingDistributionNameFinalChildrenMixin(ExpressionBase):
# Then ask ourselves to work on it.
return self.computeExpression(trace_collection)
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@abstractmethod
def computeExpression(self, trace_collection):
"""Must be overloaded for non-final node."""
@@ -999,6 +1075,17 @@ class ChildHavingElementsTupleFinalNoRaiseMixin(ExpressionBase):
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -1112,6 +1199,17 @@ class ChildHavingExpressionAttributeNameMixin(ExpressionBase):
# Then ask ourselves to work on it.
return self.computeExpression(trace_collection)
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@abstractmethod
def computeExpression(self, trace_collection):
"""Must be overloaded for non-final node."""
@@ -1244,6 +1342,17 @@ class ChildrenHavingExpressionNameRaiseWaitConstantNameMixin(ExpressionBase):
# Then ask ourselves to work on it.
return self.computeExpression(trace_collection)
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@abstractmethod
def computeExpression(self, trace_collection):
"""Must be overloaded for non-final node."""
@@ -1373,6 +1482,17 @@ class ChildHavingInterpolationsTupleStrValuesMixin(ExpressionBase):
# Then ask ourselves to work on it.
return self.computeExpression(trace_collection)
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@abstractmethod
def computeExpression(self, trace_collection):
"""Must be overloaded for non-final node."""
@@ -1498,6 +1618,17 @@ class ChildrenHavingLeftRightFinalNoRaiseMixin(ExpressionBase):
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -1599,6 +1730,17 @@ class ChildHavingListArgNoRaiseMixin(ExpressionBase):
# Then ask ourselves to work on it.
return self.computeExpression(trace_collection)
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -1732,6 +1874,17 @@ class ChildrenHavingListArgItemNoRaiseMixin(ExpressionBase):
# Then ask ourselves to work on it.
return self.computeExpression(trace_collection)
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -1866,6 +2019,17 @@ class ChildrenHavingListArgValueFinalNoRaiseMixin(ExpressionBase):
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -1987,6 +2151,17 @@ class ChildHavingPairsTupleFinalNoRaiseMixin(ExpressionBase):
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -2107,6 +2282,17 @@ class ChildHavingPromptOptionalFinalMixin(ExpressionBase):
trace_collection.onExceptionRaiseExit(BaseException)
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
def collectVariableAccesses(self, emit_variable):
"""Collect variable reads and writes of child nodes."""
@@ -2200,6 +2386,17 @@ class ChildHavingValueFinalNoRaiseMixin(ExpressionBase):
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@staticmethod
def mayRaiseExceptionOperation():
return False
@@ -2349,6 +2546,17 @@ class ChildrenHavingValueFormatSpecOptionalConversionStrValueMixin(ExpressionBas
# Then ask ourselves to work on it.
return self.computeExpression(trace_collection)
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
@abstractmethod
def computeExpression(self, trace_collection):
"""Must be overloaded for non-final node."""

View File

@@ -557,6 +557,9 @@ Replaced read-only module attribute '__spec__' with module attribute reference."
return self, None, None
def undoComputeExpressionRaw(self, trace_collection):
self.variable_trace.removeUsage()
def computeExpressionCall(self, call_node, call_args, call_kw, trace_collection):
if self.variable_trace is not None:
attribute_node = self.variable_trace.getAttributeNode()

View File

@@ -85,6 +85,9 @@ class ValueTraceBase(object):
def addUsage(self):
self.usage_count += 1
def removeUsage(self):
self.usage_count -= 1
def addNameUsage(self):
self.usage_count += 1
self.name_usage_count += 1
@@ -92,10 +95,21 @@ class ValueTraceBase(object):
if self.name_usage_count <= 2 and self.previous is not None:
self.previous.addNameUsage()
def removeNameUsage(self):
self.usage_count -= 1
self.name_usage_count -= 1
if self.name_usage_count < 2 and self.previous is not None:
self.previous.removeNameUsage()
def addMergeUsage(self):
self.usage_count += 1
self.merge_usage_count += 1
def removeMergeUsage(self):
self.usage_count -= 1
self.merge_usage_count -= 1
def getUsageCount(self):
return self.usage_count
@@ -283,14 +297,25 @@ class ValueTraceStartMixin(object):
def addUsage(self):
self.usage_count += 1
def removeUsage(self):
self.usage_count -= 1
def addMergeUsage(self):
self.usage_count += 1
self.merge_usage_count += 1
def removeMergeUsage(self):
self.usage_count -= 1
self.merge_usage_count -= 1
def addNameUsage(self):
self.usage_count += 1
self.name_usage_count += 1
def removeNameUsage(self):
self.usage_count -= 1
self.name_usage_count -= 1
@staticmethod
def getAttributeNode():
return None
@@ -454,6 +479,12 @@ class ValueTraceUnknownBase(ValueTraceBase):
if self.previous:
self.previous.addUsage()
def removeUsage(self):
self.usage_count -= 1
if self.previous:
self.previous.removeUsage()
def addMergeUsage(self):
self.usage_count += 1
self.merge_usage_count += 1
@@ -461,6 +492,13 @@ class ValueTraceUnknownBase(ValueTraceBase):
if self.previous:
self.previous.addMergeUsage()
def removeMergeUsage(self):
self.usage_count -= 1
self.merge_usage_count -= 1
if self.previous:
self.previous.removeMergeUsage()
def compareValueTrace(self, other):
# We are unknown, just need to know if the other one is, pylint: disable=no-self-use
return other.isUnknownTrace()
@@ -542,6 +580,13 @@ class ValueTraceEscaped(ValueTraceUnknown):
if self.usage_count <= 2:
self.previous.addNameUsage()
def removeUsage(self):
self.usage_count -= 1
# The previous must be prevented from optimization if still used afterwards.
if self.usage_count < 2:
self.previous.removeNameUsage()
def addMergeUsage(self):
self.usage_count += 1
if self.usage_count <= 2:
@@ -551,6 +596,15 @@ class ValueTraceEscaped(ValueTraceUnknown):
if self.merge_usage_count <= 2:
self.previous.addMergeUsage()
def removeMergeUsage(self):
self.usage_count -= 1
if self.usage_count < 2:
self.previous.removeNameUsage()
self.merge_usage_count -= 1
if self.merge_usage_count < 2:
self.previous.removeMergeUsage()
def getTypeShape(self):
return self.previous.getTypeShape()
@@ -766,6 +820,14 @@ class ValueTraceMergeBase(ValueTraceBase):
for previous in self.previous:
previous.addNameUsage()
def removeNameUsage(self):
self.usage_count -= 1
self.name_usage_count -= 1
if self.name_usage_count < 2 and self.previous is not None:
for previous in self.previous:
previous.removeNameUsage()
def addUsage(self):
self.usage_count += 1
@@ -774,10 +836,22 @@ class ValueTraceMergeBase(ValueTraceBase):
for trace in self.previous:
trace.addMergeUsage()
def removeUsage(self):
self.usage_count -= 1
# Only do it once.
if self.usage_count == 0:
for trace in self.previous:
trace.removeMergeUsage()
def addMergeUsage(self):
self.addUsage()
self.merge_usage_count += 1
def removeMergeUsage(self):
self.removeUsage()
self.merge_usage_count -= 1
def isUsingTrace(self):
# Checking definite is enough, the merges, we shall see them as well.
return self.usage_count

View File

@@ -361,6 +361,17 @@ class {{mixin_name}}({{base_class}}):
return self.computeExpression(trace_collection)
{% endif %}
def undoComputeExpressionRaw(self, trace_collection):
for child in self.getVisitableNodes():
child.undoComputeExpressionRaw(trace_collection)
self.undoComputeExpression()
# For overload only
@staticmethod
def undoComputeExpression():
pass
{% if raise_mode == "no_raise" %}
@staticmethod
def mayRaiseExceptionOperation():