-- Lua Cartesian Product

-- Date: 2023-05-15
--- Render a value as a human-readable string (debug helper).
-- Tables are rendered recursively as `{ [key]=value, ... } `; keys that are
-- not numbers are wrapped in double quotes. Non-table values use `tostring`.
-- Key order follows `pairs` and is therefore unspecified.
-- NOTE(review): there is no cycle guard, so a self-referencing table recurses
-- until the stack overflows.
-- @param o any value
-- @treturn string
function dump(o)
    if type(o) ~= 'table' then
        return tostring(o)
    end
    -- Collect pieces and join once; repeated `..` in a loop is O(n^2).
    local parts = { '{ ' }
    for k, v in pairs(o) do
        local key
        if type(k) == 'number' then
            key = tostring(k)
        else
            -- tostring() so boolean/table keys can be quoted instead of
            -- raising "attempt to concatenate".
            key = '"' .. tostring(k) .. '"'
        end
        parts[#parts + 1] = '[' .. key .. ']=' .. dump(v) .. ', '
    end
    parts[#parts + 1] = '} '
    return table.concat(parts)
end

--- Generator over the Cartesian product of a list of sequences.
-- Each yielded tuple is a fresh table `{ v1, v2, ..., vn }`, so callers may
-- keep or mutate it. Tuples come in lexicographic order: the last sequence
-- varies fastest. An empty `arrays` yields a single empty tuple.
-- Elements are read with `ipairs`, so each sequence stops at its first nil.
-- @tparam table arrays sequence of sequences
-- @treturn function iterator usable in a generic `for`
function cartesian(arrays)
    local wrap, yield = coroutine.wrap, coroutine.yield
    return wrap(function()
        local n = #arrays
        -- One working buffer, overwritten in place; only the leaf is copied.
        -- This avoids allocating a prefix table at every recursion level.
        local current = {}
        local function descend(depth)
            if depth > n then
                local tuple = {}
                for i = 1, n do
                    tuple[i] = current[i]
                end
                yield(tuple)
                return
            end
            for _, value in ipairs(arrays[depth]) do
                current[depth] = value
                descend(depth + 1)
            end
        end
        descend(1)
    end)
end

-- Sample input: 6 letters x 15 integers x 9 numeral strings.
local arrs = {
    { 'A', 'B', 'C', 'D', 'E', 'F' },
    { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 },
    { 'I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'VIIII' },
}

-- Walk the whole product with `cartesian`, echo every 50th tuple, then
-- report the CPU time spent and the number of tuples produced.
local t0 = os.clock()
local produced = 0
for tuple in cartesian(arrs) do
    produced = produced + 1
    if produced % 50 == 0 then print(dump(tuple)) end
end
print("elapsed:", os.clock() - t0, produced)

-- Another version: reuses a single result table instead of allocating a new one per tuple.

--- Generator over the Cartesian product of a list of sequences
-- (allocation-light variant).
-- The SAME `result` table is yielded every time and is overwritten on the
-- next step. Copy it if a tuple must outlive the current iteration.
-- Tuples come in lexicographic order: the last set varies fastest. An empty
-- `sets` yields one empty tuple, matching `cartesian`.
-- @tparam table sets sequence of sequences
-- @treturn function iterator usable in a generic `for`
local function cartesian_product(sets)
  local result = {}
  local set_count = #sets
  local yield = coroutine.yield

  -- Must be local: a global `descend` would be rebound by every later call
  -- to cartesian_product, so an older generator would recurse into the
  -- newest generator's sets and result table.
  local function descend(depth)
    local set = sets[depth]
    if depth == set_count then
      -- Leaf level is specialised to avoid one extra call per tuple.
      for _, v in ipairs(set) do
        result[depth] = v
        yield(result)
      end
    else
      for _, v in ipairs(set) do
        result[depth] = v
        descend(depth + 1)
      end
    end
  end

  return coroutine.wrap(function()
    if set_count == 0 then
      -- Otherwise descend(1) would index sets[1] == nil and raise.
      yield(result)
    else
      descend(1)
    end
  end)
end

-- Test code: timing harness for the two generators (method1, method2).

--- Render a value as a human-readable string (debug helper).
-- Tables are rendered recursively as `{ [key]=value, ... } `; keys that are
-- not numbers are wrapped in double quotes. Non-table values use `tostring`.
-- Key order follows `pairs` and is therefore unspecified.
-- NOTE(review): there is no cycle guard, so a self-referencing table recurses
-- until the stack overflows.
-- @param o any value
-- @treturn string
function dump(o)
    if type(o) ~= 'table' then
        return tostring(o)
    end
    -- Collect pieces and join once; repeated `..` in a loop is O(n^2).
    local parts = { '{ ' }
    for k, v in pairs(o) do
        local key
        if type(k) == 'number' then
            key = tostring(k)
        else
            -- tostring() so boolean/table keys can be quoted instead of
            -- raising "attempt to concatenate".
            key = '"' .. tostring(k) .. '"'
        end
        parts[#parts + 1] = '[' .. key .. ']=' .. dump(v) .. ', '
    end
    parts[#parts + 1] = '} '
    return table.concat(parts)
end

--- Generator over the Cartesian product of a list of sequences.
-- Each yielded tuple is a fresh table `{ v1, v2, ..., vn }`, so callers may
-- keep or mutate it. Tuples come in lexicographic order: the last sequence
-- varies fastest. An empty `arrays` yields a single empty tuple.
-- Elements are read with `ipairs`, so each sequence stops at its first nil.
-- @tparam table arrays sequence of sequences
-- @treturn function iterator usable in a generic `for`
function cartesian(arrays)
    local wrap, yield = coroutine.wrap, coroutine.yield
    return wrap(function()
        local n = #arrays
        -- One working buffer, overwritten in place; only the leaf is copied.
        -- This avoids allocating a prefix table at every recursion level.
        local current = {}
        local function descend(depth)
            if depth > n then
                local tuple = {}
                for i = 1, n do
                    tuple[i] = current[i]
                end
                yield(tuple)
                return
            end
            for _, value in ipairs(arrays[depth]) do
                current[depth] = value
                descend(depth + 1)
            end
        end
        descend(1)
    end)
end

-- Benchmark input: 10 letters x 24 integers x 10 numeral strings
-- (2400 tuples in total).
local arrs = {
    { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J' },
    { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
      13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 },
    { 'I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'VIIII', 'X' },
}

--- Benchmark the copying generator `cartesian` over `arrs`.
-- Walks the whole product, prints every 50th tuple, then prints the CPU
-- time spent (os.clock) and the number of tuples seen.
local function method1()
    local t0, seen = os.clock(), 0
    for tuple in cartesian(arrs) do
        seen = seen + 1
        if seen % 50 == 0 then print(dump(tuple)) end
    end
    local elapsed = os.clock() - t0
    print("elapsed:", elapsed, seen)
end


--- Generator over the Cartesian product of a list of sequences
-- (allocation-light variant).
-- The SAME `result` table is yielded every time and is overwritten on the
-- next step. Copy it if a tuple must outlive the current iteration.
-- Tuples come in lexicographic order: the last set varies fastest. An empty
-- `sets` yields one empty tuple, matching `cartesian`.
-- @tparam table sets sequence of sequences
-- @treturn function iterator usable in a generic `for`
local function cartesian_product(sets)
    local result = {}
    local set_count = #sets
    local yield = coroutine.yield

    -- Must be local: a global `descend` would be rebound by every later call
    -- to cartesian_product, so an older generator would recurse into the
    -- newest generator's sets and result table.
    local function descend(depth)
        local set = sets[depth]
        if depth == set_count then
            -- Leaf level is specialised to avoid one extra call per tuple.
            for _, v in ipairs(set) do
                result[depth] = v
                yield(result)
            end
        else
            for _, v in ipairs(set) do
                result[depth] = v
                descend(depth + 1)
            end
        end
    end

    return coroutine.wrap(function()
        if set_count == 0 then
            -- Otherwise descend(1) would index sets[1] == nil and raise.
            yield(result)
        else
            descend(1)
        end
    end)
end

--- Benchmark the buffer-reusing generator `cartesian_product` over `arrs`.
-- Walks the whole product, prints every 50th tuple, then prints the CPU
-- time spent (os.clock) and the number of tuples seen.
-- The previous version iterated `cartesian`, so it duplicated method1 and
-- never measured the second algorithm.
local function method2()
    local start = os.clock()
    local count = 0
    for a in cartesian_product(arrs) do
        count = count + 1
        if count % 50 == 0 then
            -- `a` is a reused buffer; dump it now, before the next step
            -- overwrites it.
            print(dump(a))
        end
    end
    local elapsed = os.clock() - start
    print("elapsed:", elapsed, count)
end

-- Run the two benchmarks in order: method1 first, then method2.
for _, run in ipairs({ method1, method2 }) do
    run()
end
-- End of "Lua Cartesian Product"