Utility: Cache Lua function calls to skip repeated queries

Here’s a utility function you can use to cache the return value of any function. It’s useful for caching large queries that you run multiple times, when their results shouldn’t change within a single load of SB.

```lua
-- cache the value a function returns
local _call_cache = {}
local cache_fn_retval = function(name, fn)
  print('CACHE(prepare):', name)
  -- build a function
  return function(...)
    -- temp variable that's local to the lambda
    local _name = name
    -- note: table.concat only accepts string and number arguments
    local key = _name .. '(' .. table.concat({...}, ',') .. ')'
    local ret = _call_cache[key]

    -- check if we have cached the result of the function call
    -- (a nil result is never cached, and only the first return value is kept)
    if ret == nil then
      -- run the function and cache its return value
      ret = fn(...)
      _call_cache[key] = ret
      print('CACHE(miss): ', key)
    else
      print('CACHE(hit):  ', key)
    end
    -- return the result of the function
    return ret
  end
end
-- shorthand for making a cached fn and putting it on an object
local cache_fn_to_object = function(obj, name, fn)
  obj[name] = cache_fn_retval(name, fn)
end
```

Feel free to remove the comments if you choose to use this. There are quite a few since it’s somewhat unusual Lua code.
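As written, `cache_fn_retval` treats `ret == nil` as "not cached yet", so a function that returns nil (say, a lookup that finds nothing) is re-run on every call. A minimal sketch of one way around that, assuming you want nil results cached too: store a private sentinel table in place of nil. `cache_fn_retval_nil`, `NIL_RESULT` and `find_page` are hypothetical names, and like the original this keeps only the first return value.

```lua
-- Hypothetical variant of cache_fn_retval that also caches nil results.
-- A private sentinel table stands in for nil, because a plain
-- `ret == nil` check cannot tell "not cached yet" from "cached nil".
local NIL_RESULT = {}
local _nil_aware_cache = {}

local cache_fn_retval_nil = function(name, fn)
  return function(...)
    local key = name .. '(' .. table.concat({...}, ',') .. ')'
    local ret = _nil_aware_cache[key]
    if ret == nil then
      -- cache miss: run the function, storing the sentinel for a nil result
      ret = fn(...)
      _nil_aware_cache[key] = (ret == nil) and NIL_RESULT or ret
      return ret
    end
    -- cache hit: translate the sentinel back to nil
    if ret == NIL_RESULT then
      return nil
    end
    return ret
  end
end

-- usage: the (hypothetical) lookup body runs only once even though it returns nil
local lookups = 0
local find_page = cache_fn_retval_nil('find_page', function(page_name)
  lookups = lookups + 1
  return nil -- pretend the page does not exist
end)
find_page('missing')
find_page('missing')
```

A `false` result is stored as-is, so it is also served from the cache on later calls.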


Examples

Example 1: query caching
```lua
-- Get all tasks
local query_all_tasks = cache_fn_retval('query_all_tasks', function()
  return query [[ from index.tag 'task' ]]
end)
```

Example 2: argument handling
```lua
local add_numbers = cache_fn_retval('add_numbers', function(a, b)
  return a + b
end)

add_numbers(1, 1) -- caches the result of 1+1
add_numbers(1, 1) -- retrieves the cached value from the previous call
add_numbers(1, 2) -- caches 1+2 (the cache key includes the arguments you pass)
add_numbers(2, 1) -- caches 2+1
add_numbers(1, 2) -- uses the cached value for 1+2
```

Example 3: `cache_fn_to_object()`

You may want to use `cache_fn_to_object()` instead of `cache_fn_retval()`, since you only type the function’s name once, instead of both as the variable name and as the string passed to `cache_fn_retval()`:

```lua
Lib = {}
-- ...
cache_fn_to_object(Lib, 'add_numbers', function(a, b)
  return a + b
end)

-- now callable as Lib.add_numbers(1, 2)
```

A note about key generation

The method for generating cache keys can produce collisions when you pass arguments of different types to your functions. For example, if a cached function is called myfunc, the following calls all produce the same key:

  • myfunc(1, 123)
  • myfunc('1', 123)
  • myfunc(1, '123')
  • myfunc('1', '123')

All of these map to the key `myfunc(1,123)`, so the cache returns the same value for each of them, even if the function would have done something different for each argument type. This shouldn’t matter in most cases, but it’s worth knowing about.
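To see the collision concretely, the sketch below rebuilds keys with the same expression `cache_fn_retval` uses (`name .. '(' .. table.concat({...}, ',') .. ')'`). `make_key` is a hypothetical stand-alone copy of that line, not a function from the post.

```lua
-- Hypothetical copy of the key expression used inside cache_fn_retval
local function make_key(name, ...)
  return name .. '(' .. table.concat({...}, ',') .. ')'
end

local keys = {
  make_key('myfunc', 1, 123),
  make_key('myfunc', '1', 123),
  make_key('myfunc', 1, '123'),
  make_key('myfunc', '1', '123'),
}

for i, key in ipairs(keys) do
  print(i, key) -- every key is the string myfunc(1,123)
end
```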

One way to fix this would be to embed type information for each argument in the cache keys, but I suspect that would make it slower. I haven’t tested that, though, so it may not have much of an impact.
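A minimal sketch of a type-tagged key, assuming standard Lua 5.3/5.4. Each argument becomes `type:length:value`, so `1` and `'1'` get different keys, a trailing `nil` still counts as an argument (via `select('#', ...)`), and the length prefix stops a string such as `'a,b'` from colliding with the two arguments `'a'` and `'b'`. `make_typed_key` is a hypothetical name, and its cost relative to the plain `table.concat` key has not been measured.

```lua
-- Hypothetical helper: build a cache key that records each argument's type
-- and the length of its string form, so different argument lists never
-- produce the same key.
local function make_typed_key(name, ...)
  local n = select('#', ...) -- counts trailing nils too, unlike #{...}
  local parts = {}
  for i = 1, n do
    local v = select(i, ...)
    local s = tostring(v)
    -- e.g. number:1:1, string:1:1, nil:3:nil, boolean:4:true
    parts[i] = type(v) .. ':' .. #s .. ':' .. s
  end
  return name .. '(' .. table.concat(parts, ',') .. ')'
end

print(make_typed_key('myfunc', 1, 123))   -- myfunc(number:1:1,number:3:123)
print(make_typed_key('myfunc', '1', 123)) -- myfunc(string:1:1,number:3:123)
```

Because it goes through `tostring`, booleans and tables no longer make `table.concat` raise an error; tables end up keyed by their address-like string, i.e. roughly by identity, unless they define a `__tostring` metamethod.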

Potential improvements

One way to improve this would be to explore cache invalidation strategies, such as setting `_call_cache[key] = nil` when pages are created or other events happen. I haven’t needed this so far, so I haven’t looked into how to do it, but it would be a nice extension for a more full-fledged implementation.
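A minimal sketch of name-based invalidation, assuming the key format above, where every entry for a function starts with `name .. '('`. `invalidate` is a hypothetical helper that takes the cache table as a parameter, so in practice it would live in the same scope as the local `_call_cache`. Calling it from page-creation or other events is not shown, since that depends on SilverBullet’s event API.

```lua
-- Hypothetical helper: drop every cached result for one function name.
-- Keys look like name .. '(' .. args .. ')', so they share that prefix;
-- the '(' keeps 'add' from also matching 'add_numbers'.
local function invalidate(cache, name)
  local prefix = name .. '('
  local removed = 0
  for key in pairs(cache) do
    -- clearing existing fields during a pairs()/next() traversal is allowed
    if key:sub(1, #prefix) == prefix then
      cache[key] = nil
      removed = removed + 1
    end
  end
  return removed
end

-- usage with a stand-in for _call_cache
local example_cache = {
  ['add_numbers(1,1)'] = 2,
  ['add_numbers(1,2)'] = 3,
  ['query_all_tasks()'] = {},
}
local removed = invalidate(example_cache, 'add_numbers')
```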


Enjoy!

Nice, yeah I think something like this should just be in the standard library.

This is a proof-of-concept (PoC) implementation of memoization. It works for me in standard Lua. I haven’t tested it in Space Lua yet, but AFAIK it should work there as well 🙂

Description:

  • caches function results keyed on all arguments
  • handles multiple return values (including nils)
  • handles nil, NaN and signed-zero (-0) arguments, which can’t be used directly as distinct table keys
  • works with arguments of any type, including tables (matched by identity, not content)
  • supports a simple full reset with the clear() function
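The sentinels used by `normalize_key` in the implementation below exist because of how plain Lua tables treat these values as keys. This sketch (standard Lua 5.3/5.4 assumed) shows that nil and NaN keys raise an error on assignment, that a `-0.0` key finds the entry stored under `0`, and that the sign of zero is still visible through division, which is the test `normalize_key` relies on.

```lua
local t = {}

-- nil key: assignment raises an error
local ok_nil = pcall(function() t[nil] = 'x' end)

-- NaN key: assignment raises an error
local ok_nan = pcall(function() t[0/0] = 'x' end)

-- signed zero: -0.0 == 0, so both refer to the same entry
t[0] = 'positive zero'
local same_slot = (t[-0.0] == 'positive zero')

-- the sign is still observable through division: 1 / -0.0 is -inf
local is_negative = (1 / -0.0 < 0)

print(ok_nil, ok_nan, same_slot, is_negative) -- false, false, true, true
```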

Implementation:

```lua
-- sentinel keys: nil and NaN can't be table keys, and -0 would share 0's slot;
-- VALUE marks where a trie node stores its cached result
local NIL, NAN, NEG_ZERO, VALUE = {}, {}, {}, {}

local function normalize_key(v)
  if v == nil then
    return NIL
  end

  if v ~= v then
    return NAN
  end

  if type(v) == 'number' and v == 0 and 1/v < 0 then
    return NEG_ZERO
  end

  return v
end

function memoize(fn)
  local root = {}

  local function memoized(...)
    local n = select('#', ...)
    local node = root

    -- walk (and build) one trie level per argument
    for i = 1, n do
      local k = normalize_key((select(i, ...)))
      local next_node = node[k]

      if next_node == nil then
        next_node = {}
        node[k] = next_node
      end
      node = next_node
    end

    -- results live under the private VALUE key rather than a string key
    -- such as '_v', so a string argument can never collide with them
    local cached = node[VALUE]
    if cached then
      -- print("CACHE(hit):", ...)
      return table.unpack(cached, 1, cached.n)
    end

    local res = table.pack(fn(...))
    node[VALUE] = res

    -- print("CACHE(miss):", ...)
    return table.unpack(res, 1, res.n)
  end

  local function clear()
    root = {}
  end

  return memoized, clear
end
```
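Because each trie node holds its keys strongly, every table ever passed as an argument stays alive for as long as the memoized function does. One possible extension, not part of the original post: give every node weak keys through the `__mode = 'k'` metatable field, so a branch can disappear once its table argument is garbage-collected. This sketch assumes Lua 5.2 or later (ephemeron tables) and uses a hypothetical `memoize_weak`; it repeats `normalize_key` so it runs on its own, and whether Space Lua honours `__mode` is untested.

```lua
-- Hypothetical weak-keyed variant of memoize()
local NIL, NAN, NEG_ZERO, VALUE = {}, {}, {}, {}

local function normalize_key(v)
  if v == nil then return NIL end
  if v ~= v then return NAN end
  if type(v) == 'number' and v == 0 and 1 / v < 0 then return NEG_ZERO end
  return v
end

-- every trie node has weak keys; strings, numbers and the sentinels above
-- stay reachable, so only entries keyed by dead tables are ever dropped
local function weak_node()
  return setmetatable({}, { __mode = 'k' })
end

local function memoize_weak(fn)
  local root = weak_node()

  local function memoized(...)
    local node = root
    for i = 1, select('#', ...) do
      local k = normalize_key((select(i, ...)))
      local next_node = node[k]
      if next_node == nil then
        next_node = weak_node()
        node[k] = next_node
      end
      node = next_node
    end

    local res = node[VALUE]
    if res == nil then
      res = table.pack(fn(...))
      node[VALUE] = res
    end
    return table.unpack(res, 1, res.n)
  end

  local function clear()
    root = weak_node()
  end

  return memoized, clear
end
```

Lookups behave exactly like `memoize`; only the lifetime of entries keyed by table arguments changes.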

Examples:

Memoized simple function…

```lua
local calls = 0
local f, f_clear = memoize(function(a, b)
  calls = calls + 1
  return a + b
end)

assert(f(1, 2) == 3 and calls == 1)   -- miss
assert(f(1, 2) == 3 and calls == 1)   -- hit
assert(f('1', 2) == 3 and calls == 2) -- miss (different key!)
f_clear()
assert(f(1, 2) == 3 and calls == 3)   -- miss after clear
```

Memoized table identity vs content…

```lua
local t1 = { v = 1 }
local t2 = t1
local t3 = { v = 1 }
local calls = 0

local f = memoize(function(t)
  calls = calls + 1
  return t.v
end)

assert(f(t1) == 1 and calls == 1)
assert(f(t2) == 1 and calls == 1) -- same identity
assert(f(t3) == 1 and calls == 2) -- different identity
```

Memoized recursive function…

```lua
local mathlib = {}

mathlib.fib = function(n)
  if n <= 1 then
    return n
  end

  return mathlib.fib(n - 1) + mathlib.fib(n - 2)
end

mathlib.fib = memoize(mathlib.fib)

assert(mathlib.fib(10) == 55)
```
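The recursive example works because the body of `mathlib.fib` looks up `mathlib.fib` in the table on every call, so after the reassignment the recursive calls also go through the memoized wrapper. The sketch below counts how often the raw body runs; to stay self-contained it uses a hypothetical single-argument `memoize1` in place of `memoize`.

```lua
-- Hypothetical single-argument memoizer standing in for memoize(),
-- so this snippet runs on its own; it caches by the argument value.
local function memoize1(fn)
  local cache = {}
  return function(n)
    local hit = cache[n]
    if hit == nil then
      hit = fn(n)
      cache[n] = hit
    end
    return hit
  end
end

local mathlib = {}
local raw_calls = 0

mathlib.fib = function(n)
  raw_calls = raw_calls + 1
  if n <= 1 then
    return n
  end
  -- resolved through the table at call time, so after the reassignment
  -- below these recursive calls hit the memoized wrapper
  return mathlib.fib(n - 1) + mathlib.fib(n - 2)
end

mathlib.fib = memoize1(mathlib.fib)

local result = mathlib.fib(30)
print(result, raw_calls) -- 832040, 31
```

With the cache, the raw body runs once per value of n from 0 to 30 (31 times); the naive recursion would run it about 2.7 million times.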

I hope it helps…


Note: the `assert(f('1', 2) == 3 and calls == 2)` assertion currently fails in Space Lua because of a “coercion in arithmetic” bug, also listed in this issue.
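For context, in standard Lua 5.3/5.4 that assertion holds. Arithmetic coerces the numeric string `'1'` to a number, so the function returns 3, while as table keys the string `'1'` and the number `1` remain distinct, so the call is still a cache miss. A small sketch of both behaviours (the variable names are only illustrative):

```lua
-- arithmetic coerces a numeric string to a number in standard Lua
local sum_from_string = '1' + 2

-- but the string '1' and the number 1 are different table keys,
-- which is why f('1', 2) misses the cache even after f(1, 2)
local keys = {}
keys[1] = 'number key'
keys['1'] = 'string key'

print(sum_from_string, keys[1], keys['1']) -- 3, number key, string key
```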