Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 65 additions & 40 deletions loopy/kernel/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import itertools
import logging
import sys
from collections.abc import Set as AbstractSet
from functools import reduce
from sys import intern
from typing import (
Expand All @@ -47,7 +46,7 @@
import pymbolic.primitives as p
from islpy import dim_type
from pymbolic import Expression
from pytools import fset_union, memoize_on_first_arg, natsorted, set_union
from pytools import fset_union, memoize_on_first_arg, natsorted

from loopy.diagnostic import LoopyError, warn_with_kernel
from loopy.kernel import LoopKernel
Expand All @@ -59,7 +58,7 @@
MultiAssignmentBase,
_DataObliviousInstruction,
)
from loopy.symbolic import CombineMapper
from loopy.symbolic import CombineMapper, Reduction
from loopy.translation_unit import (
CallableId,
CallablesTable,
Expand All @@ -70,7 +69,14 @@


if TYPE_CHECKING:
from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
from collections.abc import (
Callable,
Collection,
Iterable,
Mapping,
Sequence,
Set as AbstractSet,
)

from pymbolic import ArithmeticExpression
from pytools.tag import Tag
Expand Down Expand Up @@ -2246,68 +2252,87 @@ def get_hw_axis_base_for_codegen(kernel: LoopKernel, iname: str) -> isl.Aff:

# {{{ get access map from an instruction

def union_amaps(amaps: Sequence[isl.Map]) -> isl.Map:
    """Return the union of all access maps in *amaps*.

    The maps are folded pairwise, left to right, with
    :meth:`islpy.Map.union`.

    :arg amaps: a non-empty sequence of access maps.
    :returns: a single map that is the union of every entry of *amaps*.
    :raises IndexError: if *amaps* is empty, as there is no map to start
        the union from.
    """
    if not amaps:
        # An empty input previously surfaced as a bare "index out of range"
        # from ``amaps[0]``; keep the exception type, but say what is wrong.
        raise IndexError("union_amaps() requires at least one access map")

    import islpy as isl

    # With a non-empty sequence this is equivalent to seeding the fold with
    # ``amaps[0]`` and iterating over ``amaps[1:]``, without the slice copy.
    return reduce(isl.Map.union, amaps)


@dataclasses.dataclass
class _IndexCollector(CombineMapper[AbstractSet[tuple[Expression, ...]], []]):
class _InstructionAccessMapCollector(
CombineMapper[dict[frozenset[str], isl.Map], [isl.Set]]):
knl: LoopKernel
var: str

def __post_init__(self) -> None:
super().__init__()

@override
def combine(self,
values: Iterable[AbstractSet[tuple[Expression, ...]]]
) -> AbstractSet[tuple[Expression, ...]]:
return set_union(values)
def combine(
self,
values: Iterable[dict[frozenset[str], isl.Map]]
) -> dict[frozenset[str], isl.Map]:
result: dict[frozenset[str], isl.Map] = {}
for value in values:
for inames, amap in value.items():
try:
old_amap = result[inames]
except KeyError:
result[inames] = amap
else:
result[inames] = union_amaps((old_amap, amap))
return result

@override
def map_reduction(
self, expr: Reduction, domain: isl.Set) -> dict[frozenset[str], isl.Map]:
new_domain = self.knl.get_inames_domain(
frozenset(domain.get_var_dict(dim_type.set))
| frozenset(expr.inames)).to_set()
return super().map_reduction(expr, new_domain)

@override
def map_subscript(self, expr: p.Subscript) -> AbstractSet[tuple[Expression, ...]]:
def map_subscript(
self, expr: p.Subscript, domain: isl.Set) -> dict[frozenset[str], isl.Map]:
from loopy.symbolic import get_access_map
assert isinstance(expr.aggregate, p.Variable)
if expr.aggregate.name == self.var:
return (super().map_subscript(expr) | frozenset([expr.index_tuple]))
inames = frozenset(domain.get_var_dict(dim_type.set).keys())
amap = get_access_map(
domain, expr.index_tuple, self.knl.assumptions)
return self.combine([
super().map_subscript(expr, domain), {inames: amap}])
else:
return super().map_subscript(expr)
return super().map_subscript(expr, domain)

@override
def map_algebraic_leaf(
self, expr: p.AlgebraicLeaf,
) -> frozenset[tuple[Expression, ...]]:
return frozenset()
self, expr: p.AlgebraicLeaf, domain: isl.Set,
) -> dict[frozenset[str], isl.Map]:
return {}

@override
def map_constant(
self, expr: object
) -> frozenset[tuple[Expression, ...]]:
return frozenset()


def _union_amaps(amaps: Sequence[isl.Map]):
import islpy as isl
return reduce(isl.Map.union, amaps[1:], amaps[0])
self, expr: object, domain: isl.Set) -> dict[frozenset[str], isl.Map]:
return {}


def get_insn_access_map(kernel: LoopKernel, insn_id: str, var: str):
def get_insn_access_maps(
kernel: LoopKernel, insn_id: str, var: str) -> list[isl.Map]:
from loopy.match import Id
from loopy.symbolic import get_access_map
from loopy.transform.subst import expand_subst

insn = kernel.id_to_insn[insn_id]

kernel = expand_subst(kernel, within=Id(insn_id))
indices = tuple(
_IndexCollector(var)(
(insn.expression, insn.assignees, tuple(insn.predicates))
)
)

amaps = [
get_access_map(
kernel.get_inames_domain(insn.within_inames).to_set(),
idx, kernel.assumptions
)
for idx in indices
]
insn = kernel.id_to_insn[insn_id]
insn_inames = kernel.insn_inames(insn)
inames_domain = kernel.get_inames_domain(insn_inames)
domain = inames_domain.project_out_except(
insn_inames, [dim_type.set]).to_set()

inames_to_amap = _InstructionAccessMapCollector(kernel, var)(
(insn.expression, insn.assignees, tuple(insn.predicates)), domain)

return _union_amaps(amaps)
return list(inames_to_amap.values())

# }}}

Expand Down
26 changes: 17 additions & 9 deletions loopy/transform/loop_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,23 +390,31 @@ def _compute_isinfusible_via_access_map(
import pymbolic.primitives as prim

from loopy.diagnostic import UnableToDetermineAccessRangeError
from loopy.kernel.tools import get_insn_access_map
from loopy.kernel.tools import get_insn_access_maps
from loopy.symbolic import isl_set_from_expr

try:
amap_pred = get_insn_access_map(kernel, insn_pred, var)
amap_succ = get_insn_access_map(kernel, insn_succ, var)
amaps_pred = get_insn_access_maps(kernel, insn_pred, var)
amaps_succ = get_insn_access_maps(kernel, insn_succ, var)
except UnableToDetermineAccessRangeError:
# Either the predecessor or the successor has a non-affine access, so
# fall back to the safer option => infusible.
return True

amap_pred = amap_pred.project_out_except(
outer_inames | {candidate_pred}, [isl.dim_type.param, isl.dim_type.in_]
)
amap_succ = amap_succ.project_out_except(
outer_inames | {candidate_succ}, [isl.dim_type.param, isl.dim_type.in_]
)
amaps_pred = [
amap.project_out_except(
outer_inames | {candidate_pred}, [isl.dim_type.param, isl.dim_type.in_])
for amap in amaps_pred]
amaps_succ = [
amap.project_out_except(
outer_inames | {candidate_succ}, [isl.dim_type.param, isl.dim_type.in_])
for amap in amaps_succ]

# amaps should have the same space after projecting out the inner loops, so they
# can safely be unioned
from loopy.kernel.tools import union_amaps
amap_pred = union_amaps(amaps_pred)
amap_succ = union_amaps(amaps_succ)

# move outer inames to param
for outer_iname in sorted(outer_inames):
Expand Down
33 changes: 33 additions & 0 deletions test/test_loop_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,39 @@ def test_reduction_loop_fusion_with_multiple_redn_in_same_insn(
lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit.with_kernel(knl))


def test_loop_fusion_with_inner_reduction(ctx_factory: cl.CtxFactory):
    """Fuse ``i0`` with ``i1`` when the consumer reduces over its own ``j1``.

    ``insn1`` writes ``a`` over ``(i0, j0)``; ``insn2`` reads it back inside
    a ``sum`` over ``j1``. After the fusion candidates for ``i0``/``i1`` are
    applied, the two instructions must share exactly one iname, and the
    transformed kernel is compared against the untransformed one.
    """
    ctx = ctx_factory()

    domains = [
        "{[i0, j0]: 0 <= i0, j0 < 10}",
        "{[i1]: 0 <= i1 < 10}",
        # Intentionally keeping j1 separate from i1 to test for regression. See
        # https://github.com/inducer/loopy/pull/1009 for details.
        "{[j1]: 0 <= j1 < 10}",
    ]
    instructions = """
a[i0, j0] = j0 * 1.0 {id=insn1}
out[i1] = sum(j1, a[i1, j1]) {id=insn2}
"""

    t_unit = lp.make_kernel(domains, instructions)
    ref_t_unit = t_unit

    knl = t_unit.default_entrypoint
    fused_chunks = lp.get_kennedy_unweighted_fusion_candidates(
        knl, frozenset(["i0", "i1"]))
    knl = lp.rename_inames_in_batch(knl, fused_chunks)

    # Only the fused outer loop should be common to producer and consumer.
    producer_inames = knl.id_to_insn["insn1"].within_inames
    consumer_inames = knl.id_to_insn["insn2"].within_inames
    assert len(producer_inames & consumer_inames) == 1

    lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit.with_kernel(knl))


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down
Loading