Lua尾调用优化详解

深入理解Lua尾调用优化机制,掌握尾递归消除、状态机实现和无限递归的安全写法。

什么是尾调用

尾调用(tail call)是指一个函数在其最后一步操作中调用另一个函数,且调用完成后不需要做任何额外工作就直接返回。在 Lua 中,尾调用不会消耗额外的栈空间,因为 Lua 会复用当前栈帧。

-- 这是尾调用:调用 g() 后 f 直接返回
function f(x)
    return g(x)
end

-- 这不是尾调用:调用 g() 后还需要执行加法
function f(x)
    return g(x) + 1
end

-- 这不是尾调用:调用 g() 后还需要执行 print
function f(x)
    g(x)
    print("done")
end

Lua 对尾调用的处理方式是栈帧复用:当 f 尾调用 g 时,Lua 会丢弃 f 的栈帧,直接在 f 的栈帧上执行 g。这意味着无论嵌套多少层尾调用,栈空间始终保持不变。

尾调用 vs 普通递归

普通递归(会栈溢出)

function factorial_bad(n, acc)
    acc = acc or 1
    if n <= 1 then
        return acc
    end
    -- 这不是尾调用!因为递归返回后还要做乘法
    return n * factorial_bad(n - 1, acc)
end

-- 当 n 很大时会栈溢出
local ok, err = pcall(factorial_bad, 100000)
print(ok, err)  -- false  stack overflow

尾递归版本(安全)

function factorial_good(n, acc)
    acc = acc or 1
    if n <= 1 then
        return acc
    end
    -- 这是尾调用:递归后直接返回
    return factorial_good(n - 1, n * acc)
end

-- 无论 n 多大都不会栈溢出
print(factorial_good(100000))  -- 正常执行(虽然结果溢出)

尾调用的语法要求

Lua 中只有 return func(args) 这种形式才是尾调用:

-- 是尾调用
function a() return b() end

-- 不是尾调用(返回值需要调整)
function a() return b(), 1 end

-- 不是尾调用(返回值需要调整)
function a() return b(), c() end

-- 不是尾调用(需要额外操作)
function a() return (b()) end

-- 不是尾调用(需要类型转换)
function a() return tonumber(b()) end

-- 不是尾调用(在赋值后)
function a()
    local x = b()
    return x
end

-- 是尾调用(在条件分支中)
function a(x)
    if x > 0 then
        return b(x)
    else
        return c(x)
    end
end

相互尾递归

尾调用不限于自递归,两个函数可以相互尾调用:

-- 判断奇偶(相互尾递归)
local is_even, is_odd

function is_even(n)
    if n == 0 then return true end
    return is_odd(n - 1)
end

function is_odd(n)
    if n == 0 then return false end
    return is_even(n - 1)
end

print(is_even(1000000))  -- true(不会栈溢出)
print(is_odd(1000001))   -- true

用尾调用实现状态机

尾调用是实现状态机的理想工具,每个状态是一个函数,状态转换通过尾调用实现:

-- 简单的自动售货机状态机
local function state_idle()
    print("状态: 待机,等待投币...")
    local input = io.read()
    if input == "coin" then
        return state_has_coin()
    end
    return state_idle()
end

local function state_has_coin()
    print("状态: 已投币,选择商品 (a/b/c)...")
    local input = io.read()
    if input == "a" or input == "b" or input == "c" then
        return state_dispensing(input)
    end
    print("无效选择")
    return state_has_coin()
end

local function state_dispensing(item)
    print("正在出货: " .. item)
    print("出货完成!")
    return state_idle()
end

-- 启动状态机(可以无限运行而不栈溢出)
-- state_idle()

游戏 AI 状态机

-- 游戏 NPC 的 AI 状态机
local function make_npc(name, hp)
    local self = {name = name, hp = hp, target = nil}

    function self:patrol()
        print(self.name .. " 正在巡逻...")
        -- 检查是否发现敌人
        if self:can_see_enemy() then
            return self:chase()
        end
        return self:patrol()
    end

    function self:chase()
        print(self.name .. " 正在追击...")
        if self:in_attack_range() then
            return self:attack()
        end
        if self:lost_target() then
            return self:patrol()
        end
        return self:chase()
    end

    function self:attack()
        print(self.name .. " 正在攻击!")
        if not self:in_attack_range() then
            return self:chase()
        end
        if self.hp < 20 then
            return self:flee()
        end
        return self:attack()
    end

    function self:flee()
        print(self.name .. " 正在逃跑!")
        if self.hp > 50 then
            return self:patrol()
        end
        return self:flee()
    end

    -- 模拟辅助方法
    function self:can_see_enemy() return math.random() > 0.7 end
    function self:in_attack_range() return math.random() > 0.5 end
    function self:lost_target() return math.random() > 0.8 end

    return self
end

尾调用实现链表遍历

-- 链表定义
local function cons(head, tail)
    return {head = head, tail = tail}
end

-- 尾递归遍历
local function print_list(node)
    if node == nil then return end
    print(node.head)
    return print_list(node.tail)
end

-- 尾递归求长度
local function list_length(node, acc)
    acc = acc or 0
    if node == nil then return acc end
    return list_length(node.tail, acc + 1)
end

-- 尾递归反转
local function list_reverse(node, acc)
    if node == nil then return acc end
    return list_reverse(node.tail, cons(node.head, acc))
end

-- 构建链表 1 -> 2 -> 3 -> nil
local list = cons(1, cons(2, cons(3, nil)))
print_list(list)
print("长度:", list_length(list))  -- 3

尾调用实现解释器

-- 简单的算术表达式解释器
local function eval(expr, env)
    if type(expr) == "number" then
        return expr
    end

    if type(expr) == "string" then
        return env[expr]
    end

    local op = expr[1]
    if op == "+" then
        return eval(expr[2], env) + eval(expr[3], env)
    elseif op == "-" then
        return eval(expr[2], env) - eval(expr[3], env)
    elseif op == "*" then
        return eval(expr[2], env) * eval(expr[3], env)
    elseif op == "/" then
        return eval(expr[2], env) / eval(expr[3], env)
    elseif op == "if" then
        -- 条件分支可以尾调用
        if eval(expr[2], env) then
            return eval(expr[3], env)
        else
            return eval(expr[4], env)
        end
    end
end

-- 表达式: (3 + 5) * 2
local result = eval({"*", {"+", 3, 5}, 2}, {})
print(result)  -- 16

-- 表达式: if (x > 5) then x * 2 else x + 1
local result2 = eval(
    {"if", {">", "x", 5}, {"*", "x", 2}, {"+", "x", 1}},
    {x = 10}
)
-- 注意:这里没有实现 > 运算符,仅作结构展示

尾递归转迭代

当不能使用尾调用时,可以手动转换为迭代:

-- 尾递归版本
local function sum_tail(n, acc)
    acc = acc or 0
    if n <= 0 then return acc end
    return sum_tail(n - 1, acc + n)
end

-- 等价的迭代版本
local function sum_iter(n)
    local acc = 0
    while n > 0 do
        acc = acc + n
        n = n - 1
    end
    return acc
end

-- 树遍历:尾递归
local function tree_sum_tail(node, acc)
    acc = acc or 0
    if not node then return acc end
    -- 左子树用尾递归
    acc = tree_sum_iter(node.left, acc + node.value)
    -- 右子树尾调用
    return tree_sum_tail(node.right, acc)
end

检测尾调用

使用 debug 库可以验证函数调用是否使用了尾调用:

function non_tail()
    local info = debug.getinfo(2, "t")
    return info and info.what
end

function tail_call()
    return non_tail()  -- 尾调用
end

function no_tail_call()
    local x = non_tail()  -- 不是尾调用
    return x
end

-- 使用 debug.traceback 查看调用栈
local function check_tail()
    local co = coroutine.create(function()
        return tail_call()
    end)
    coroutine.resume(co)
end

注意事项

使用尾调用优化需要注意以下要点:

  • Lua 的尾调用优化是编译器保证的,不是可选优化
  • 只有 return func(args) 形式才是尾调用,return func(args), extra 不是
  • 尾调用会丢失调用栈信息,debug.traceback 中看不到被复用的栈帧
  • LuaJIT 对尾调用有同样的支持
  • 尾调用不影响 xpcall 的错误处理,错误会正确传播
  • 在设计状态机时,确保每个状态转换都是尾调用
  • 如果需要保留调用栈(如调试),应避免使用尾调用

继续阅读

探索更多技术文章

浏览归档,发现更多关于系统设计、工具链和工程实践的内容。

全部文章 返回首页