gpt4 book ai didi

loops - 为 ST monad 编写高效的迭代循环

转载 作者:行者123 更新时间:2023-12-03 15:02:56 25 4
gpt4 key购买 nike

go worker tail-recursive loop pattern似乎非常适合编写纯代码。为 ST 编写这种循环的等效方法是什么?单子(monad)?更具体地说,我想避免在循环迭代中分配新的堆。我的猜测是它涉及 CPS transformationfixST重写代码,以便在循环中更改的所有值都在每次迭代中传递,从而使寄存器位置(或溢出的情况下的堆栈)在迭代中可用于这些值。我在下面有一个简化的示例(不要尝试运行它 - 它可能会因段错误而崩溃!)涉及一个名为 findSnakes 的函数其中有一个 go worker 模式,但不断变化的状态值不通过累加器参数传递:

{-# LANGUAGE BangPatterns #-}
module Test where

import Data.Vector.Unboxed.Mutable as MU
import Data.Vector.Unboxed as U hiding (mapM_)
import Control.Monad.ST as ST
import Control.Monad.Primitive (PrimState)
import Control.Monad as CM (when,forM_)
import Data.Int

type MVI1 s = MVector (PrimState (ST s)) Int

-- function to find previous y
findYP :: MVI1 s -> Int -> Int -> ST s Int
findYP fp k offset = do
y0 <- MU.unsafeRead fp (k+offset-1) >>= \x -> return $ 1+x
y1 <- MU.unsafeRead fp (k+offset+1)
if y0 > y1 then return y0
else return y1
{-#INLINE findYP #-}

findSnakes :: Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp !k !ct !op = go 0 k
where
offset=1+U.length a
go x k'
| x < ct = do
yp <- findYP fp k' offset
MU.unsafeWrite fp (k'+offset) (yp + k')
go (x+1) (op k' 1)
| otherwise = return ()
{-#INLINE findSnakes #-}

cmm输出 ghc 7.6.1 (我对 cmm 的了解有限 - 如果我弄错了请纠正我),我看到了这种调用流程,在 s1tb_info 中有循环(这会导致每次迭代中的堆分配和堆检查):
findSnakes_info -> a1_r1qd_info -> $wa_r1qc_info (new stack allocation, SpLim check)
-> s1sy_info -> s1sj_info: if arg > 1 then s1w8_info else R1 (can't figure out
what that register points to)

-- I am guessing this one below is for go loop
s1w8_info -> s1w7_info (big heap allocation, HpLim check) -> s1tb_info: if arg >= 1
then s1td_info else R1

s1td_info (big heap allocation, HpLim check) -> if arg >= 1 then s1tb_info
(a loop) else s1tb_info (after executing a different block of code)

我的猜测是检查表单 arg >= 1cmm代码是确定是否 go循环是否已终止。如果这是正确的,似乎除非 go循环被重写以通过 yp跨循环,堆分配将跨循环发生新值(我猜 yp导致堆分配)。写 go 的有效方法是什么?在上面的例子中循环?我猜 yp必须在 go 中作为参数传递循环,或通过 fixST 的等效方式或 CPS转型。想不出改写的好方法 go上面的循环以删除堆分配,并将感谢您的帮助。

最佳答案

我重写了您的函数以避免任何显式递归,并删除了一些计算偏移量的冗余操作。这编译成比您的原始函数更好的核心。

顺便说一句,Core 可能是分析编译代码以进行此类分析的更好方法。使用ghc -ddump-simpl查看生成的核心输出,或 ghc-core 等工具

import Control.Monad.Primitive                                                                               
import Control.Monad.ST
import Data.Int
import qualified Data.Vector.Unboxed.Mutable as M
import qualified Data.Vector.Unboxed as U

type MVI1 s = M.MVector (PrimState (ST s)) Int

findYP :: MVI1 s -> Int -> ST s Int
findYP fp offset = do
y0 <- M.unsafeRead fp (offset+0)
y1 <- M.unsafeRead fp (offset+2)
return $ max (y0 + 1) y1

findSnakes :: U.Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp k0 ct op = U.mapM_ writeAt $ U.iterateN ct (`op` 1) k0
where writeAt k = do
let offset = U.length a + k
yp <- findYP fp offset
M.unsafeWrite fp (offset + 1) (yp + k)

-- or inline findYP manually
writeAt k = do
let offset = U.length a + k
y0 <- M.unsafeRead fp (offset + 0)
y1 <- M.unsafeRead fp (offset + 2)
M.unsafeWrite fp (offset + 1) (k + max (y0 + 1) y1)

此外,您传递了 U.Vector Int32findSnakes , 只计算它的长度并且从不使用 a再次。为什么不直接传入长度呢?

关于loops - 为 ST monad 编写高效的迭代循环,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/17126923/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com