gpt4 book ai didi

haskell - 有什么办法可以内联递归函数吗?

转载 作者:行者123 更新时间:2023-12-02 19:41:17 27 4
gpt4 key购买 nike

这是我的 previous question 的后续内容我问为什么流融合没有在某个程序中启动。事实证明,问题在于某些函数未内联,而 INLINE 标志将性能提高了约 17x(这展示了内联的重要性!)。

现在,请注意,在最初的问题上,我立即对 incAll64 调用进行了硬编码。现在,假设我创建一个 nTimes 函数,它重复调用一个函数:

module Main where

import qualified Data.Vector.Unboxed as V

{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll = V.map (+ 1)

{-# INLINE nTimes #-}
nTimes :: Int -> (a -> a) -> a -> a
nTimes 0 f x = x
nTimes n f x = f (nTimes (n-1) f x)

main :: IO ()
main = do
let size = 100000000 :: Int
let array = V.replicate size 0 :: V.Vector Int
print $ V.sum (nTimes 64 incAll array)

在这种情况下,仅向 nTimes 添加 INLINE pragma 是没有帮助的,因为 AFAIK GHC 不内联递归函数。有没有什么技巧可以强制 GHC 在编译时扩展 nTimes ,从而恢复预期的性能?

最佳答案

不,但是您可以使用更好的功能。我不是在谈论V.map (+64),它肯定会让事情变得更快,而是在谈论nTimes。我们有三位候选人已经做了 nTimes 所做的事情:

{-# INLINE nTimesFoldr #-}
nTimesFoldr :: Int -> (a -> a) -> a -> a
nTimesFoldr n f x = foldr (.) id (replicate n f) $ x

{-# INLINE nTimesIterate #-}
nTimesIterate :: Int -> (a -> a) -> a -> a
nTimesIterate n f x = iterate f x !! n

{-# INLINE nTimesTail #-}
nTimesTail :: Int -> (a -> a) -> a -> a
nTimesTail n f = go n
where
{-# INLINE go #-}
go n x | n <= 0 = x
go n x = go (n - 1) (f x)

所有版本大约需要 8 秒,而您的版本需要 40 秒。顺便说一句,约阿希姆的版本也需要 8 秒。请注意,iterate 版本在我的系统上占用更多内存。虽然有一个 unroll plugin对于 GHC,它在过去五年内没有更新(它使用自定义 ANNotations)。

根本没有展开?

但是,在我们绝望之前,GHC 实际上尝试内联所有内容的效果如何?让我们使用 nTimesTailnTimes 1:

module Main where
import qualified Data.Vector.Unboxed as V

{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll = V.map (+ 1)

{-# INLINE nTimes #-}
nTimes :: Int -> (a -> a) -> a -> a
nTimes n f = go n
where
{-# INLINE go #-}
go n x | n <= 0 = x
go n x = go (n - 1) (f x)

main :: IO ()
main = do
let size = 100000000 :: Int
let array = V.replicate size 0 :: V.Vector Int
print $ V.sum (nTimes 1 incAll array)
$ stack ghc --package vector -- -O2 -ddump-simpl -dsuppress-all SO.hs
main2 =
case (runSTRep main3) `cast` ...
of _ { Vector ww1_s9vw ww2_s9vx ww3_s9vy ->
case $wgo 1 ww1_s9vw ww2_s9vx ww3_s9vy
of _ { (# ww5_s9w3, ww6_s9w4, ww7_s9w5 #) ->

我们可以就停在那里。 $wgo 是上面定义的 go。即使使用 1 GHC 也不会展开循环。这很令人不安,因为 1 是一个常数。

模板来救援

但是可惜,这一切并非都是徒劳。如果 C++ 程序员能够对编译时常量执行以下操作,那么我们也应该这样做,对吗?

template <int N>
struct Call{
template <class F, class T>
static T call(F f, T && t){
return f(Call<N-1>::call(f,std::forward<T>(t)));
}
};
template <>
struct Call<0>{
template <class F, class T>
static T call(F f, T && t){
return t;
}
};

果然,我们可以,用 TemplateHaskell *:

-- Times.sh
{-# LANGUAGE TemplateHaskell #-}
module Times where

import Control.Monad (when)
import Language.Haskell.TH

nTimesTH :: Int -> Q Exp
nTimesTH n = do
f <- newName "f"
x <- newName "x"

when (n <= 0) (reportWarning "nTimesTH: argument non-positive")

let go k | k <= 0 = VarE x
go k = AppE (VarE f) (go (k - 1))
return $ LamE [VarP f,VarP x] (go n)

nTimesTH 是做什么的?它创建一个新函数,其中第一个名称 f 将应用于第二个名称 x 总共 n 次。 n 现在需要是一个编译时常量,这适合我们,因为循环展开只能使用编译时常量:

$(nTimesTH 0) = \f x -> x
$(nTimesTH 1) = \f x -> f x
$(nTimesTH 2) = \f x -> f (f x)
$(nTimesTH 3) = \f x -> f (f (f x))
...

有效果吗?而且速度快吗?与nTimes相比有多快?让我们尝试另一个 main 来实现:

-- SO.hs
{-# LANGUAGE TemplateHaskell #-}
module Main where
import Times
import qualified Data.Vector.Unboxed as V

{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll = V.map (+ 1)

{-# INLINE nTimes #-}
nTimes :: Int -> (a -> a) -> a -> a
nTimes n f = go n
where
{-# INLINE go #-}
go n x | n <= 0 = x
go n x = go (n - 1) (f x)

main :: IO ()
main = do
let size = 100000000 :: Int
let array = V.replicate size 0 :: V.Vector Int
let vTH = V.sum ($(nTimesTH 64) incAll array)
let vNorm = V.sum (nTimes 64 incAll array)
print $ vTH == vNorm
stack ghc --package vector -- -O2 SO.hs && SO.exe +RTS -t
True
<<ghc: 52000056768 bytes, 66 GCs, 400034700/800026736 avg/max bytes residency (2 samples), 1527M in use, 0.000 INIT (0.000 elapsed), 8.875 MUT (9.119 elapsed), 0.000 GC (0.094 elapsed) :ghc>>

它产生了正确的结果。有多快?让我们再次使用另一个main:

main :: IO ()
main = do
let size = 100000000 :: Int
let array = V.replicate size 0 :: V.Vector Int
print $ V.sum ($(nTimesTH 64) incAll array)
     800,048,112 bytes allocated in the heap                                         
4,352 bytes copied during GC
42,664 bytes maximum residency (1 sample(s))
18,776 bytes maximum slop
764 MB total memory in use (0 MB lost due to fragmentation)

Tot time (elapsed) Avg pause Max pause
Gen 0 1 colls, 0 par 0.000s 0.000s 0.0000s 0.0000s
Gen 1 1 colls, 0 par 0.000s 0.049s 0.0488s 0.0488s

INIT time 0.000s ( 0.000s elapsed)
MUT time 0.172s ( 0.221s elapsed)
GC time 0.000s ( 0.049s elapsed)
EXIT time 0.000s ( 0.049s elapsed)
Total time 0.188s ( 0.319s elapsed)

%GC time 0.0% (15.3% elapsed)

Alloc rate 4,654,825,378 bytes per MUT second

Productivity 100.0% of total user, 58.7% of total elapsed

好吧,将其与 8 进行比较。因此,对于TL;DR:如果您有编译时常量,并且您想根据该常量创建和/或修改代码,请考虑 Template Haskell。

* 请注意,这是我编写的第一个模板 Haskell 代码。小心使用。不要使用太大的 n,否则您可能会得到一个困惑的函数。

关于haskell - 有什么办法可以内联递归函数吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42179783/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com