Optimization: Also decide writing traces for parameter variables faster

This commit is contained in:
Kay Hayen
2025-11-17 20:57:05 +01:00
parent c2482c7a6d
commit ec8e28efef
4 changed files with 89 additions and 77 deletions

View File

@@ -175,12 +175,22 @@ class Variable(getMetaClassBase("Variable", require_slots=True)):
return self.traces return self.traces
def hasEmptyTracesFor(self, owner): def hasEmptyTracesFor(self, owner):
"""Do these traces contain any usage."""
for trace in self.traces: for trace in self.traces:
if trace.owner is owner and trace.isUsingTrace(): if trace.owner is owner and trace.isUsingTrace():
return False return False
return True return True
def hasNoWritingTraces(self):
"""Do these traces contain any writes."""
for trace in self.traces:
if trace.isWritingTrace():
return False
return True
def updateUsageState(self): def updateUsageState(self):
writers = set() writers = set()
users = set() users = set()
@@ -191,7 +201,7 @@ class Variable(getMetaClassBase("Variable", require_slots=True)):
if trace.isAssignTrace(): if trace.isAssignTrace():
writers.add(owner) writers.add(owner)
elif trace.isDeletedTrace() and owner is not self.owner: elif owner is not self.owner and trace.isDeletedTrace():
writers.add(owner) writers.add(owner)
self.writers = writers self.writers = writers

View File

@@ -416,7 +416,7 @@ class ExpressionFunctionBodyBase(
def optimizeVariableReleases(self): def optimizeVariableReleases(self):
for parameter_variable in self.getParameterVariablesWithManualRelease(): for parameter_variable in self.getParameterVariablesWithManualRelease():
read_only = self.trace_collection.hasReadOnlyTraces(parameter_variable) read_only = parameter_variable.hasNoWritingTraces()
if read_only: if read_only:
self.trace_collection.signalChange( self.trace_collection.signalChange(

View File

@@ -88,17 +88,6 @@ class CollectionUpdateMixin(object):
def getVariableTrace(self, variable, version): def getVariableTrace(self, variable, version):
return self.variable_traces[(variable, version)] return self.variable_traces[(variable, version)]
def getVariableTraces(self, variable):
result = []
for key, variable_trace in iterItems(self.variable_traces):
candidate = key[0]
if variable is candidate:
result.append(variable_trace)
return result
def getVariableTracesAll(self): def getVariableTracesAll(self):
return self.variable_traces return self.variable_traces
@@ -208,11 +197,7 @@ class CollectionStartPointMixin(CollectionUpdateMixin):
def getExceptionRaiseCollections(self): def getExceptionRaiseCollections(self):
return self.exception_collections return self.exception_collections
def hasReadOnlyTraces(self, variable): # TODO: Eliminate this function, call it directly where this is used.
# TODO: Combine these steps into one for performance gains.
traces = self.getVariableTraces(variable)
return areReadOnlyTraces(traces)
def updateVariablesFromCollection(self, old_collection, source_ref): def updateVariablesFromCollection(self, old_collection, source_ref):
Variables.updateVariablesFromCollection(old_collection, self, source_ref) Variables.updateVariablesFromCollection(old_collection, self, source_ref)
@@ -384,7 +369,7 @@ class TraceCollectionBase(object):
self.addVariableTrace( self.addVariableTrace(
variable, variable,
version, version,
ValueTraceEscaped(owner=self.owner, previous=current), ValueTraceEscaped(self.owner, current),
) )
self.markCurrentVariableTrace(variable, version) self.markCurrentVariableTrace(variable, version)
@@ -401,22 +386,25 @@ class TraceCollectionBase(object):
self.addVariableTrace( self.addVariableTrace(
variable, variable,
version, version,
ValueTraceUnknown(owner=self.owner, previous=current), ValueTraceUnknown(self.owner, current),
) )
self.markCurrentVariableTrace(variable, version) self.markCurrentVariableTrace(variable, version)
def markActiveVariableAsUnknown(self, variable): def markActiveVariableAsUnknown(self, variable):
current = self.getVariableCurrentTrace(variable) version = self.variable_actives[variable]
current = self.getVariableTrace(variable, version)
if not current.isUnknownOrVeryTrustedTrace(): if not current.isUnknownOrVeryTrustedTrace():
version = variable.allocateTargetNumber() # Unknown traces are div 3 rem 2.
version = version // 3 * 3 + 2
self.addVariableTrace( if not self.hasVariableTrace(variable, version):
variable, self.addVariableTrace(
version, variable,
ValueTraceUnknown(owner=self.owner, previous=current), version,
) ValueTraceUnknown(self.owner, current),
)
self.markCurrentVariableTrace(variable, version) self.markCurrentVariableTrace(variable, version)
@@ -487,9 +475,9 @@ class TraceCollectionBase(object):
def onVariableSet(self, variable, version, assign_node): def onVariableSet(self, variable, version, assign_node):
variable_trace = ValueTraceAssign( variable_trace = ValueTraceAssign(
owner=self.owner, self.owner,
assign_node=assign_node, assign_node,
previous=self.getVariableCurrentTrace(variable), self.getVariableCurrentTrace(variable),
) )
self.addVariableTrace(variable, version, variable_trace) self.addVariableTrace(variable, version, variable_trace)
@@ -521,9 +509,9 @@ class TraceCollectionBase(object):
def onVariableSetToUnescapableValue(self, variable, version, assign_node): def onVariableSetToUnescapableValue(self, variable, version, assign_node):
variable_trace = ValueTraceAssignUnescapable( variable_trace = ValueTraceAssignUnescapable(
owner=self.owner, self.owner,
assign_node=assign_node, assign_node,
previous=self.getVariableCurrentTrace(variable), self.getVariableCurrentTrace(variable),
) )
self.addVariableTrace(variable, version, variable_trace) self.addVariableTrace(variable, version, variable_trace)
@@ -535,9 +523,9 @@ class TraceCollectionBase(object):
def onVariableSetToVeryTrustedValue(self, variable, version, assign_node): def onVariableSetToVeryTrustedValue(self, variable, version, assign_node):
variable_trace = ValueTraceAssignVeryTrusted( variable_trace = ValueTraceAssignVeryTrusted(
owner=self.owner, self.owner,
assign_node=assign_node, assign_node,
previous=self.getVariableCurrentTrace(variable), self.getVariableCurrentTrace(variable),
) )
self.addVariableTrace(variable, version, variable_trace) self.addVariableTrace(variable, version, variable_trace)
@@ -551,10 +539,10 @@ class TraceCollectionBase(object):
self, variable, version, assign_node, replacement self, variable, version, assign_node, replacement
): ):
variable_trace = ValueTraceAssignUnescapablePropagated( variable_trace = ValueTraceAssignUnescapablePropagated(
owner=self.owner, self.owner,
assign_node=assign_node, assign_node,
previous=self.getVariableCurrentTrace(variable), self.getVariableCurrentTrace(variable),
replacement=replacement, replacement,
) )
self.addVariableTrace(variable, version, variable_trace) self.addVariableTrace(variable, version, variable_trace)
@@ -572,7 +560,9 @@ class TraceCollectionBase(object):
# TODO: Annotate value content as escaped. # TODO: Annotate value content as escaped.
variable_trace = ValueTraceDeleted( variable_trace = ValueTraceDeleted(
owner=self.owner, del_node=del_node, previous=old_trace self.owner,
old_trace,
del_node,
) )
# Assign to not initialized again. # Assign to not initialized again.
@@ -833,6 +823,8 @@ class TraceCollectionBase(object):
with TimerReport( with TimerReport(
message="Running merge for %s took %%.2f seconds" % collections, message="Running merge for %s took %%.2f seconds" % collections,
decider=False, decider=False,
include_sleep_time=False,
use_perf_counters=False,
): ):
new_actives = {} new_actives = {}
@@ -967,14 +959,14 @@ class TraceCollectionBase(object):
) )
def initVariableUnknown(self, variable): def initVariableUnknown(self, variable):
trace = ValueTraceUnknown(owner=self.owner, previous=None) trace = ValueTraceUnknown(self.owner, None)
self.addVariableTrace(variable, 0, trace) self.addVariableTrace(variable, 0, trace)
return trace return trace
def initVariableModule(self, variable): def initVariableModule(self, variable):
trace = ValueTraceUnknown(owner=self.owner, previous=None) trace = ValueTraceUnknown(self.owner, None)
self.addVariableTrace(variable, 0, trace) self.addVariableTrace(variable, 0, trace)
@@ -1002,7 +994,7 @@ class TraceCollectionBase(object):
return trace return trace
def initVariableUninitialized(self, variable): def initVariableUninitialized(self, variable):
trace = ValueTraceUninitialized(owner=self.owner, previous=None) trace = ValueTraceUninitialized(self.owner, None)
self.addVariableTrace(variable, 0, trace) self.addVariableTrace(variable, 0, trace)
@@ -1152,7 +1144,9 @@ class TraceCollectionFunction(CollectionStartPointMixin, TraceCollectionBase):
return TraceCollectionBase.initVariableModule(self, variable) return TraceCollectionBase.initVariableModule(self, variable)
assign_trace = ValueTraceAssignVeryTrusted( assign_trace = ValueTraceAssignVeryTrusted(
self.owner, assign_node=trusted_node.getParent(), previous=None self.owner,
trusted_node.getParent(),
None,
) )
self.addVariableTrace(variable, 0, assign_trace) self.addVariableTrace(variable, 0, assign_trace)
@@ -1254,36 +1248,6 @@ class TraceCollectionModule(CollectionStartPointMixin, TraceCollectionBase):
self.distribution_names[distribution_name] = success self.distribution_names[distribution_name] = success
def areReadOnlyTraces(variable_traces):
"""Do these traces contain any writes."""
# Many cases immediately return, that is how we do it here,
for variable_trace in variable_traces:
if variable_trace.isAssignTrace():
return False
elif variable_trace.isInitTrace():
pass
elif variable_trace.isDeletedTrace():
# A "del" statement can do this, and needs to prevent variable
# from being not released.
return False
elif variable_trace.isUninitializedTrace():
pass
elif variable_trace.isUnknownTrace():
return False
elif variable_trace.isEscapeTrace():
pass
elif variable_trace.isMergeTrace():
pass
elif variable_trace.isLoopTrace():
pass
else:
assert False, variable_trace
return True
def _checkActivesDiff(a1, a2): def _checkActivesDiff(a1, a2):
result = True result = True
for variable in a1: for variable in a1:

View File

@@ -134,6 +134,10 @@ class ValueTraceBase(object):
def isUsingTrace(self): def isUsingTrace(self):
"""Is the trace indicating a usage of the variable.""" """Is the trace indicating a usage of the variable."""
@abstractmethod
def isWritingTrace(self):
"""Is the trace indicating a usage of the variable."""
@staticmethod @staticmethod
def isAssignTrace(): def isAssignTrace():
return False return False
@@ -310,6 +314,10 @@ class ValueTraceUninitialized(ValueTraceUnassignedBase):
def isUsingTrace(self): def isUsingTrace(self):
return self.usage_count return self.usage_count
@staticmethod
def isWritingTrace():
return False
@staticmethod @staticmethod
def isTraceThatNeedsEscape(): def isTraceThatNeedsEscape():
return False return False
@@ -336,6 +344,10 @@ class ValueTraceDeleted(ValueTraceUnassignedBase):
def isUsingTrace(): def isUsingTrace():
return True return True
@staticmethod
def isWritingTrace():
return True
def getDelNode(self): def getDelNode(self):
return self.del_node return self.del_node
@@ -366,6 +378,10 @@ class ValueTraceInit(ValueTraceBase):
def isUsingTrace(): def isUsingTrace():
return True return True
@staticmethod
def isWritingTrace():
return False
@staticmethod @staticmethod
def mustHaveValue(): def mustHaveValue():
return True return True
@@ -438,6 +454,10 @@ class ValueTraceUnknown(ValueTraceBase):
def isUsingTrace(self): def isUsingTrace(self):
return self.usage_count return self.usage_count
@staticmethod
def isWritingTrace():
return True
@staticmethod @staticmethod
def isUnknownOrVeryTrustedTrace(): def isUnknownOrVeryTrustedTrace():
return True return True
@@ -518,6 +538,10 @@ class ValueTraceEscaped(ValueTraceUnknown):
def isEscapeTrace(): def isEscapeTrace():
return True return True
@staticmethod
def isWritingTrace():
return False
def isUsingTrace(self): def isUsingTrace(self):
return self.usage_count return self.usage_count
@@ -568,6 +592,10 @@ class ValueTraceAssign(ValueTraceBase):
def isUsingTrace(): def isUsingTrace():
return True return True
@staticmethod
def isWritingTrace():
return True
def compareValueTrace(self, other): def compareValueTrace(self, other):
return other.isAssignTrace() and self.assign_node is other.assign_node return other.isAssignTrace() and self.assign_node is other.assign_node
@@ -673,6 +701,8 @@ class ValueTraceAssignUnescapablePropagated(ValueTraceAssignUnescapable):
class ValueTraceMergeBase(ValueTraceBase): class ValueTraceMergeBase(ValueTraceBase):
"""Merge of two or more traces or start of loops.""" """Merge of two or more traces or start of loops."""
# Base classes can be abstract, pylint: disable=I0021,abstract-method
__slots__ = () __slots__ = ()
def addNameUsage(self): def addNameUsage(self):
@@ -722,9 +752,9 @@ class ValueTraceMerge(ValueTraceMergeBase):
if trace not in shorted: if trace not in shorted:
shorted.append(trace) shorted.append(trace)
traces = tuple(shorted) assert len(shorted) > 1, traces
assert len(traces) > 1 traces = tuple(shorted)
# assert len(set(traces)) == len(traces), [(v) for v in traces] # assert len(set(traces)) == len(traces), [(v) for v in traces]
@@ -775,6 +805,10 @@ class ValueTraceMerge(ValueTraceMergeBase):
# Checking definite is enough, the merges, we shall see them as well. # Checking definite is enough, the merges, we shall see them as well.
return self.usage_count return self.usage_count
@staticmethod
def isWritingTrace():
return False
def compareValueTrace(self, other): def compareValueTrace(self, other):
if not other.isMergeTrace(): if not other.isMergeTrace():
return False return False
@@ -866,6 +900,10 @@ class ValueTraceLoopBase(ValueTraceMergeBase):
def isUsingTrace(): def isUsingTrace():
return True return True
@staticmethod
def isWritingTrace():
return False
def getTypeShape(self): def getTypeShape(self):
if self.type_shape is None: if self.type_shape is None:
if len(self.type_shapes) > 1: if len(self.type_shapes) > 1: