diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index 7fabaface..906e2ae2a 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -239,7 +239,7 @@ def __init__(self, rule_mapping_context, subst_name, subst_tag, within, access_descriptors, array_base_map, storage_axis_names, storage_axis_sources, non1_storage_axis_names, - temporary_name, compute_insn_id, compute_dep_id, + temporary_name, compute_dep_ids, compute_read_variables): super().__init__(rule_mapping_context) @@ -255,8 +255,7 @@ def __init__(self, rule_mapping_context, subst_name, subst_tag, within, self.non1_storage_axis_names = non1_storage_axis_names self.temporary_name = temporary_name - self.compute_insn_id = compute_insn_id - self.compute_dep_id = compute_dep_id + self.compute_dep_ids = compute_dep_ids self.compute_read_variables = compute_read_variables self.compute_insn_depends_on = set() @@ -339,7 +338,6 @@ def map_subst_rule(self, name, tag, arguments, expn_state): def map_kernel(self, kernel): new_insns = [] - excluded_insn_ids = {self.compute_insn_id, self.compute_dep_id} # precomputed_in_insns: set of insn ids in which the subst rule was # precomputed. precomputed_in_insns = set() @@ -359,20 +357,19 @@ def map_kernel(self, kernel): if self.replaced_something: insn = insn.copy( depends_on=( - insn.depends_on - | frozenset([self.compute_dep_id]))) + insn.depends_on | self.compute_dep_ids)) precomputed_in_insns.add(insn.id) for dep in insn.depends_on: - if dep in excluded_insn_ids: + if dep in self.compute_dep_ids: continue dep_insn = kernel.id_to_insn[dep] if (frozenset(dep_insn.assignee_var_names()) & self.compute_read_variables): self.compute_insn_depends_on.update( - insn.depends_on - excluded_insn_ids) + insn.depends_on - self.compute_dep_ids) new_insns.append(insn) @@ -967,7 +964,7 @@ def add_assumptions(d): expression=compute_expression, # within_inames determined below ) - compute_dep_id = compute_insn_id + compute_dep_ids = {compute_insn_id} added_compute_insns: list[InstructionBase] = [compute_insn] if temporary_address_space == AddressSpace.GLOBAL: @@ -979,7 +976,7 @@ def add_assumptions(d): depends_on=frozenset([compute_insn_id]), synchronization_kind="global", mem_kind="global") - compute_dep_id = barrier_insn_id + compute_dep_ids.add(barrier_insn_id) added_compute_insns.append(barrier_insn) @@ -995,7 +992,7 @@ def add_assumptions(d): access_descriptors, abm, storage_axis_names, storage_axis_sources, non1_storage_axis_names, - temporary_name, compute_insn_id, compute_dep_id, + temporary_name, frozenset(compute_dep_ids), compute_read_variables=get_dependencies(expander(compute_expression))) kernel = invr.map_kernel(kernel)