[Libre-soc-dev] proposed better parallel tree reduction algorithm for spec

Jacob Lifshay programmerjake at gmail.com
Mon Sep 5 15:04:32 BST 2022


The new algorithm doesn't need any new SPRs or larger SVSTEP
srcstep/dststep fields. It also leaves only the mask bit corresponding
to the destination element set, so a simple predicated scalar svp64 move
after the reduction instruction can move the result into whatever
register you want. It doesn't need special REMAP handling either. Finally,
it is interruptible and resumes correctly after any number of
single-element operations.

https://godbolt.org/z/oYx3sdfG8

import itertools
# The maximum VL supported by the hardware must be a power of two;
# LOG2_HW_VL_LIMIT is log2 of that maximum (6 -> a maximum VL of 64).
# tree_reduce uses it as the upper bound on the number of reduction levels.
LOG2_HW_VL_LIMIT = 6

def find(mask, remap, VL, range_):
    """Return the first index in ``range_`` whose remapped mask bit is set.

    Indices are taken from ``range_`` in iteration order; the search ends at
    the first index that is not below ``VL``.  Each candidate ``idx`` is
    tested via ``mask[remap[idx]]``.  Returns ``None`` when no set bit is
    found before the search ends.
    """
    def _below_vl():
        # yield indices until the first one at or beyond VL, then stop
        for idx in range_:
            if idx >= VL:
                return
            yield idx

    return next((idx for idx in _below_vl() if mask[remap[idx]]), None)

def reduce_elm(gprs, remap, mask, src1, src2):
    """Fold logical element ``src2`` into logical element ``src1``.

    Both indices are logical; ``gprs`` and ``mask`` are addressed through
    ``remap``.  The mask bit of ``src2`` is cleared first, then
    ``gprs[remap[src1]] += gprs[remap[src2]]``.  Both lists are modified in
    place.
    """
    src_reg = remap[src2]
    mask[src_reg] = False
    dst_reg = remap[src1]
    gprs[dst_reg] += gprs[src_reg]

def tree_reduce(gprs, remap, mask, VL):
    """Pairwise tree-reduce the active elements of a vector, in place.

    Works on logical element indices 0..VL-1, with ``gprs`` and ``mask``
    addressed through ``remap``.  At level ``k`` the elements are split into
    groups of ``2 << k``.  In each group, the first active element of the
    upper half is folded into the first active element of the lower half
    via ``reduce_elm``, which also clears the folded element's mask bit.
    Because the lower-half element is always the left operand, the original
    element order is preserved in the combined value.

    Provided VL does not exceed ``1 << LOG2_HW_VL_LIMIT``, at most one mask
    bit is left set afterwards: the one for the element holding the result.

    No progress counter is kept.  Re-running from the start after any number
    of ``reduce_elm`` calls resumes in the correct spot, because groups that
    were already reduced hold at most one active element and so produce no
    pair.
    """
    # log2_half_step doesn't need to be stored in an SPR (see docstring)
    for log2_half_step in range(LOG2_HW_VL_LIMIT):
        half_step = 1 << log2_half_step
        if half_step >= VL:
            # every upper half now starts at or beyond VL, so find() can
            # never return a src2 -- remaining levels are no-ops
            break
        step = 2 * half_step
        # step_ctr is src/dststep in SVSTATE
        for step_ctr in range(0, VL, step):
            # all the finds in this algorithm can be combined into a tree of
            # similar cost to 1 trailing-zero-count over the remapped input
            # mask.
            # find here is kinda like svp64 predicated scalar-dest
            src1 = find(mask, remap, VL, range(step_ctr, step_ctr + half_step))
            src2 = find(mask, remap, VL,
                        range(step_ctr + half_step, step_ctr + step))
            if src1 is not None and src2 is not None:
                reduce_elm(gprs, remap, mask, src1, src2)

# test/demo: exhaustively check every mask pattern for small VL using the
# identity remap.  Each register holds a single letter, so the reduction
# result must be the in-order concatenation of the letters whose mask bit
# was set.
for VL in range(5):
    # itertools.product(..., repeat=0) yields exactly one empty tuple, so
    # VL == 0 is exercised without a special case.
    for orig_mask in itertools.product((False, True), repeat=VL):
        mask = list(orig_mask)
        gprs = [chr(ord("A") + i) for i in range(VL)]
        expected = "".join(reg for reg, active in zip(gprs, mask) if active)
        remap = list(range(VL))
        print("\nVL=", VL, "mask=", mask, "gprs=", gprs, "remap=", remap)
        tree_reduce(gprs, remap, mask, VL)
        print("VL=", VL, "mask=", mask, "gprs=", gprs, "remap=", remap)
        if expected == "":
            assert not any(mask)
        else:
            result = find(mask, remap, VL, range(VL))
            assert result is not None
            # exactly one bit (the result's) may remain set; counting the
            # whole mask stays correct for any remap, not just the identity
            assert sum(mask) == 1, "too many mask elements left set"
            assert gprs[remap[result]] == expected

Jacob


More information about the Libre-soc-dev mailing list