[Libre-soc-dev] proposed better parallel tree reduction algorithm for spec
Jacob Lifshay
programmerjake at gmail.com
Mon Sep 5 15:04:32 BST 2022
The new algorithm doesn't need any new SPRs or larger SVSTATE
srcstep/dststep fields. It also leaves the mask with only the bit
corresponding to the destination element set, so a simple predicated scalar
svp64 move after the reduction instruction can move the result into
whatever register you want. It also doesn't need special remap handling,
and it is interruptible: it resumes correctly after any number of
single-element operations.
https://godbolt.org/z/oYx3sdfG8
import itertools
# LOG2_HW_VL_LIMIT is log2 of the largest VL the hardware supports.  That
# largest VL is required to be a power of 2, so it is exactly
# 1 << LOG2_HW_VL_LIMIT (i.e. 64 with the value below).
LOG2_HW_VL_LIMIT = 6
def find(mask, remap, VL, range_):
    """Return the first index in ``range_`` whose remapped mask bit is set.

    Indices are examined in iteration order and the bit tested for index
    ``idx`` is ``mask[remap[idx]]``.  Scanning stops at the first index that
    is not below ``VL`` (the call sites shown here all pass an ascending
    ``range``, so nothing after that point would be in bounds either).

    Returns the matching index, or ``None`` when ``range_`` runs out or
    reaches ``VL`` without a set bit.
    """
    for idx in range_:
        if idx >= VL:
            # out of bounds: give up rather than read past the vector
            return None
        if mask[remap[idx]]:
            return idx
    return None
def reduce_elm(gprs, remap, mask, src1, src2):
    """Fold element ``src2`` into element ``src1`` and deactivate ``src2``.

    ``src1`` and ``src2`` are element indices; they are translated through
    ``remap`` to the register/mask slots that are actually touched.  The
    mask bit of ``src2``'s slot is cleared first, then its register value is
    added (``+=``) onto ``src1``'s register.  ``gprs`` and ``mask`` are
    modified in place; nothing is returned.
    """
    victim = remap[src2]
    mask[victim] = False
    gprs[remap[src1]] += gprs[victim]
def tree_reduce(gprs, remap, mask, VL):
    """Reduce the active elements of a vector into one element, in place.

    The reduction proceeds level by level over aligned blocks of element
    indices of size ``step = 2 * half_step`` with ``half_step = 1, 2, 4, ...``.
    In each block, the first active element of the lower half (src1) absorbs
    the first active element of the upper half (src2) via ``reduce_elm``,
    which also clears src2's mask bit.  Afterwards at most one of the mask
    bits for elements ``0..VL-1`` is still set: the one belonging to the
    first originally-active element, whose register holds the ``+=``
    combination of all originally-active elements in element-index order.

    Args:
        gprs: register values; element ``i`` lives in ``gprs[remap[i]]``.
        remap: maps element index to register/mask index.
        mask: predicate bits indexed like ``gprs``; updated in place.
        VL: vector length (number of elements).
    """
    # log2_half_step doesn't need to be stored in an SPR: once a level has
    # completed, every earlier level is a no-op, so this algorithm can be
    # restarted after any number of reduce_elm calls and it will resume in
    # the correct spot.
    #
    # A level with half_step >= VL can never find a src2 (its whole search
    # range is >= VL), so stopping there performs exactly the same
    # reduce_elm calls as iterating all LOG2_HW_VL_LIMIT levels whenever
    # VL <= 2 ** LOG2_HW_VL_LIMIT, and it still completes the reduction for
    # larger VL instead of silently stopping part-way.
    half_step = 1
    while half_step < VL:
        step = 2 * half_step
        # step_ctr is src/dststep in SVSTATE
        for step_ctr in range(0, VL, step):
            # all the finds in this algorithm can be combined into a tree
            # of similar cost to 1 trailing-zero-count over the remapped
            # input mask
            # find here is kinda like svp64 predicated scalar-dest
            lower_half = range(step_ctr, step_ctr + half_step)
            upper_half = range(step_ctr + half_step, step_ctr + step)
            src1 = find(mask, remap, VL, lower_half)
            src2 = find(mask, remap, VL, upper_half)
            if src1 is not None and src2 is not None:
                reduce_elm(gprs, remap, mask, src1, src2)
        half_step = step
# test/demo: exhaustively check every mask for small VL, under both the
# identity remap and a reversed remap.
for VL in range(5):
    identity = tuple(range(VL))
    # dict.fromkeys de-duplicates while preserving order (for VL <= 1 the
    # reversed remap is the same as the identity remap)
    remaps = dict.fromkeys([identity, identity[::-1]])
    # itertools.product(..., repeat=0) yields exactly one empty tuple, so
    # VL == 0 is exercised without any special case
    for orig_mask in itertools.product((False, True), repeat=VL):
        for remap_key in remaps:
            remap = list(remap_key)
            mask = list(orig_mask)
            gprs = [chr(ord("A") + i) for i in range(VL)]
            # the reduction combines elements in remapped (element) order,
            # so the expected result is the concatenation in that order
            expected = "".join(
                gprs[remap[i]] for i in range(VL) if mask[remap[i]]
            )
            print("\nVL=", VL, "mask=", mask, "gprs=", gprs, "remap=", remap)
            tree_reduce(gprs, remap, mask, VL)
            print("VL=", VL, "mask=", mask, "gprs=", gprs, "remap=", remap)
            if expected == "":
                assert not any(mask)
            else:
                result = find(mask, remap, VL, range(VL))
                assert result is not None
                # mask[remap[result]] is set, so this means it is the only
                # bit left set anywhere (remap is a permutation)
                assert sum(mask) == 1, "too many mask elements left set"
                assert gprs[remap[result]] == expected
Jacob
More information about the Libre-soc-dev
mailing list