Performance 在Haskell中执行常量空间嵌套循环的正确方法是什么?

Performance 在Haskell中执行常量空间嵌套循环的正确方法是什么?,performance,loops,haskell,Performance,Loops,Haskell,在Haskell中执行嵌套循环有两种明显的“惯用”方法:使用列表monad或使用forM\uuu替换传统的fors。我设置了一个基准来确定这些是否编译为紧密循环: import Control.Monad.Loop import Control.Monad.Primitive import Control.Monad import Control.Monad.IO.Class import qualified Data.Vector.Unboxed.Mutable as MV import qu

在Haskell中执行嵌套循环有两种明显的“惯用”方法:使用列表monad或使用
forM\uuu
替换传统的
fors
。我设置了一个基准来确定这些是否编译为紧密循环:

import Control.Monad.Loop
import Control.Monad.Primitive
import Control.Monad
import Control.Monad.IO.Class
import qualified Data.Vector.Unboxed.Mutable as MV
import qualified Data.Vector.Unboxed as V

times = 100000
side  = 100

-- Using `forM_` to replace traditional fors
test_a mvec = 
    forM_ [0..times-1] $ \ n -> do
        forM_ [0..side-1] $ \ y -> do
            forM_ [0..side-1] $ \ x -> do
                MV.write mvec (y*side+x) 1

-- Using the list monad to replace traditional forms
test_b mvec = sequence_ $ do
    n <- [0..times-1]
    y <- [0..side-1]
    x <- [0..side-1]
    return $ MV.write mvec (y*side+x) 1

main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    -- test_a mvec
    -- test_b mvec
    vec' <- V.unsafeFreeze mvec :: IO (V.Vector Int)
    print $ V.sum vec'
这个等价的JavaScript程序需要1s才能完成,超过了Haskell的unboxed vectors,这是不寻常的,这表明Haskell不是在常量空间中运行循环,而是执行分配。然后我发现了一个库,它声称提供类型保证的紧循环
Control.Monad.Loop

-- Using `for` from Control.Monad.Loop
test_c mvec = exec_ $ do
    n <- for 0 (< times) (+ 1)
    x <- for 0 (< side) (+ 1)
    y <- for 0 (< side) (+ 1)
    liftIO (MV.write mvec (y*side+x) 1)
——在Control.Monad.Loop中使用'for'
test_c mvec=exec_uu$do

根据我的经验,表单[0..n-1]
可以很好地执行,但不幸的是它不可靠。只需将
内联
pragma添加到
测试a
并使用
-O2
即可使其运行得更快(对我来说是4s到1s),但手动内联(复制粘贴)会再次降低运行速度

一个更可靠的功能是从中实现的

-- | Simple for loop.  Counts from /start/ to /end/-1.
for :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for n0 !n f = loop n0
  where
    loop i | i == n    = return ()
           | otherwise = f i >> loop (i+1)
{-# INLINE for #-}
使用类似于
表单
的列表:

test_d :: MV.IOVector Int -> IO ()
test_d mv =
  for 0 times $ \_ ->
    for 0 side $ \i ->
      for 0 side $ \j ->
        MV.unsafeWrite mv (i*side + j) 1

但是性能可靠(对我来说是0.85秒),没有任何分配列表的风险。

用GHC编写紧凑的变异代码有时会很棘手。我将要写一些不同的东西,可能是以一种更为杂乱无章的方式;医生比我更喜欢

首先,我们应该在任何情况下使用GHC 7.10,因为
表单
和列表单子解决方案永远不会融合

另外,我用
MV.unsafeWrite
替换了
MV.write
,部分原因是它速度更快,但更重要的是它减少了合成核心中的一些混乱。从现在起,运行时统计信息将使用
unsafeWrite
引用代码

可怕的风在飘 即使使用GHC 7.10,我们也应该首先注意所有那些
[0..times-1]
[0..side-1]
表达式,因为如果我们不采取必要的步骤,它们每次都会破坏性能。问题是它们是恒定的范围,
-ffull laziness
(默认情况下在
-O
上启用)将它们浮动到顶层。这就防止了列表融合,而且在
Int#
范围内迭代比在装箱的
Int
-s列表上迭代要便宜,所以这是一个非常糟糕的优化

让我们看看未更改代码(除了使用
unsafeWrite
)的一些以秒为单位的运行时<代码>使用ghc-O2-fllvm,我使用
+RTS-s
进行计时

test_a: 1.6
test_b: 6.2
test_c: 0.6
对于GHC核心查看,我使用了
GHC-O2-ddump simp-dsuppress all-dno suppress类型签名

测试a
的情况下,
[0..99]
范围被提升:

main4 :: [Int]
main4 = eftInt 0 99 -- means "enumFromTo" for Int.
尽管最外层的
[0..9999]
循环融合为尾部递归辅助对象:

letrec {
          a3_s7xL :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a3_s7xL =
            \ (x_X5zl :: Int#) (s1_X4QY :: State# RealWorld) ->
              case a2_s7xF 0 s1_X4QY of _ { (# ipv2_a4NA, ipv3_a4NB #) ->
              case x_X5zl of wild_X1S {
                __DEFAULT -> a3_s7xL (+# wild_X1S 1) ipv2_a4NA;
                99999 -> (# ipv2_a4NA, () #)
              }
              }; }
测试b
的情况下,再次仅提升
[0..99]
。然而,
test\u b
要慢得多,因为它必须构建实际的
[IO()]
列表并对其排序。至少GHC足够明智,只为两个内部循环构建一个
[IO()]
,然后对其进行排序
10000次

 let {
          lvl7_s4M5 :: [IO ()]
          lvl7_s4M5 = -- omitted
        letrec {
          a2_s7Av :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7Av =
            \ (x_a5xi :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7Au
                  :: [IO ()] -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7Au =
                  \ (ds_a4Nu :: [IO ()]) (eta1_X1c :: State# RealWorld) ->
                    case ds_a4Nu of _ {
                      [] ->
                        case x_a5xi of wild1_X1y {
                          __DEFAULT -> a2_s7Av (+# wild1_X1y 1) eta1_X1c;
                          99999 -> (# eta1_X1c, () #)
                        };
                      : y_a4Nz ys_a4NA ->
                        case (y_a4Nz `cast` ...) eta1_X1c
                        of _ { (# ipv2_a4Nf, ipv3_a4Ng #) ->
                        a3_s7Au ys_a4NA ipv2_a4Nf
                        }
                    }; } in
              a3_s7Au lvl7_s4M5 eta_B1; } in
-- omitted
我们怎样才能补救呢?我们可以用
{-#OPTIONS\u GHC-fno full laziness}
来解决这个问题。这对我们的情况确实有很大帮助:

test_a: 0.5
test_b: 0.48
test_c: 0.5
或者,我们可以随意使用
INLINE
pragmas。显然,let浮动完成后的内联函数保持了良好的性能。我发现GHC甚至在没有pragma的情况下内联我们的测试函数,但是显式pragma导致它只在let floating之后内联。例如,这会导致良好的性能,而不会出现完全的惰性

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE test_a #-}
但过早内联会导致性能不佳:

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE [~2] test_a #-} -- "inline before the first phase please"
这个
INLINE
解决方案的问题在于,面对GHC的浮动冲击,它相当脆弱。例如,手动内联不能保持性能。下面的代码速度很慢,因为与内联[~2]
类似,它给GHC一个浮动的机会:

main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1    
在真正恒定的空间中循环 起初,我对堆分配的
+RTS-s
数据感到非常困惑<代码>测试a
非平凡地分配有
-fn完全惰性
,也有
测试c
没有完全惰性,这些分配与
迭代次数成线性比例,但是
测试b
完全惰性只分配给向量:

-- with -fno-full-laziness, no INLINE pragmas
test_a: 242,521,008 bytes
test_b: 121,008 bytes
test_c: 121,008 bytes -- but 240,120,984 with full laziness!
而且,
test\u c
INLINE
pragmas在这种情况下根本没有帮助

我花了一些时间试图在相关程序的核心中找到堆分配的迹象,但没有成功,直到我突然意识到:GHC堆栈帧在堆上,包括主线程的帧,而执行堆分配的函数实际上是在最多三个堆栈帧中运行三次嵌套循环。
+RTS-s
注册的堆分配只是堆栈帧的不断弹出和推送

这在以下代码的核心中非常明显:

{-# OPTIONS_GHC -fno-full-laziness #-}

-- ...

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    test_a mvec
现在堆分配保持完全相同,因为最里面的循环是尾部递归的,并且使用单个帧。通过以下更改,堆分配减半(至124921008字节),因为我们推送和弹出的帧数减半:

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-50] $ \ y -> -- change here
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
test_b
test_c
(没有完全惰性)编译为在单个堆栈框架内使用嵌套case构造的代码,并遍历索引以查看哪个索引应该递增。请参阅核心部分以了解以下主要内容:

{-# LANGUAGE BangPatterns #-} -- later I'll talk about this
{-# OPTIONS_GHC -fno-full-laziness #-}

main = do
    let vec = V.generate (side*side) (const 0)
    !mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    test_c mvec
在GHC-7.10.2中,这会在不进行优化的情况下打印一次
“$$$”
,而在
-O2
中打印三次。看来,使用GHC-7.10,我们无法通过
-fno state hack
(w
main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
  \ (s_a5HK :: State# RealWorld) ->
    case divInt# 9223372036854775807 8 of ww4_a5vr { __DEFAULT ->

    -- start of vector creation ----------------------
    case tagToEnum# (># 10000 ww4_a5vr) of _ {
      False ->
        case newByteArray# 80000 (s_a5HK `cast` ...)
        of _ { (# ipv_a5fv, ipv1_a5fw #) ->
        letrec {
          $s$wa_s8jS
            :: Int#
               -> Int#
               -> State# (PrimState IO)
               -> (# State# (PrimState IO), Int #)
          $s$wa_s8jS =
            \ (sc_s8jO :: Int#)
              (sc1_s8jP :: Int#)
              (sc2_s8jR :: State# (PrimState IO)) ->
              case tagToEnum# (<# sc1_s8jP 10000) of _ {
                False -> (# sc2_s8jR, I# sc_s8jO #);
                True ->
                  case writeIntArray# ipv1_a5fw sc_s8jO 0 (sc2_s8jR `cast` ...)
                  of s'#_a5Gn { __DEFAULT ->
                  $s$wa_s8jS (+# sc_s8jO 1) (+# sc1_s8jP 1) (s'#_a5Gn `cast` ...)
                  }
              }; } in
        case $s$wa_s8jS 0 0 (ipv_a5fv `cast` ...)
        -- end of vector creation -------------------

        of _ { (# ipv6_a4Hv, ipv7_a4Hw #) ->
        letrec {
          a2_s7MJ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7MJ =
            \ (x_a5Ho :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7ME :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7ME =
                  \ (x1_X5Id :: Int#) (eta1_XR :: State# RealWorld) ->
                    case ipv7_a4Hw of _ { I# dt4_a5x6 ->
                    case writeIntArray#
                           (ipv1_a5fw `cast` ...) (*# x1_X5Id 100) 1 (eta1_XR `cast` ...)
                    of s'#_a5Gn { __DEFAULT ->
                    letrec {
                      a4_s7Mz :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                      a4_s7Mz =
                        \ (x2_X5J8 :: Int#) (eta2_X1U :: State# RealWorld) ->
                          case writeIntArray#
                                 (ipv1_a5fw `cast` ...)
                                 (+# (*# x1_X5Id 100) x2_X5J8)
                                 1
                                 (eta2_X1U `cast` ...)
                          of s'#1_X5Hf { __DEFAULT ->
                          case x2_X5J8 of wild_X2o {
                            __DEFAULT -> a4_s7Mz (+# wild_X2o 1) (s'#1_X5Hf `cast` ...);
                            99 -> (# s'#1_X5Hf `cast` ..., () #)
                          }
                          }; } in
                    case a4_s7Mz 1 (s'#_a5Gn `cast` ...)
                    of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
                    case x1_X5Id of wild_X1e {
                      __DEFAULT -> a3_s7ME (+# wild_X1e 1) ipv2_a4QH;
                      99 -> (# ipv2_a4QH, () #)
                    }
                    }
                    }
                    }; } in
              case a3_s7ME 0 eta_B1 of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
              case x_a5Ho of wild_X1a {
                __DEFAULT -> a2_s7MJ (+# wild_X1a 1) ipv2_a4QH;
                99999 -> (# ipv2_a4QH, () #)
              }
              }; } in
        a2_s7MJ 0 (ipv6_a4Hv `cast` ...)
        }
        };
      True ->
        case error
               (unpackAppendCString#
                  "Primitive.basicUnsafeNew: length to large: "#
                  (case $wshowSignedInt 0 10000 ([])
                   of _ { (# ww5_a5wm, ww6_a5wn #) ->
                   : ww5_a5wm ww6_a5wn
                   }))
        of wild_00 {
        }
    }
    }

main :: IO ()
main = main1 `cast` ...

main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)

main :: IO ()
main = main2 `cast` ...
test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-50] $ \ x -> -- change here
                MV.unsafeWrite mvec (y*side+x) 1
test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-50] $ \ y -> -- change here
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
{-# LANGUAGE BangPatterns #-} -- later I'll talk about this
{-# OPTIONS_GHC -fno-full-laziness #-}

main = do
    let vec = V.generate (side*side) (const 0)
    !mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    test_c mvec
main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
  \ (s_a5Iw :: State# RealWorld) ->
    case divInt# 9223372036854775807 8 of ww4_a5vT { __DEFAULT ->

    -- start of vector creation ----------------------
    case tagToEnum# (># 10000 ww4_a5vT) of _ {
      False ->
        case newByteArray# 80000 (s_a5Iw `cast` ...)
        of _ { (# ipv_a5g3, ipv1_a5g4 #) ->
        letrec {
          $s$wa_s8ji
            :: Int#
               -> Int#
               -> State# (PrimState IO)
               -> (# State# (PrimState IO), Int #)
          $s$wa_s8ji =
            \ (sc_s8je :: Int#)
              (sc1_s8jf :: Int#)
              (sc2_s8jh :: State# (PrimState IO)) ->
              case tagToEnum# (<# sc1_s8jf 10000) of _ {
                False -> (# sc2_s8jh, I# sc_s8je #);
                True ->
                  case writeIntArray# ipv1_a5g4 sc_s8je 0 (sc2_s8jh `cast` ...)
                  of s'#_a5GP { __DEFAULT ->
                  $s$wa_s8ji (+# sc_s8je 1) (+# sc1_s8jf 1) (s'#_a5GP `cast` ...)
                  }
              }; } in
        case $s$wa_s8ji 0 0 (ipv_a5g3 `cast` ...)
        of _ { (# ipv6_a4MX, ipv7_a4MY #) ->
        case ipv7_a4MY of _ { I# dt4_a5xy ->
        -- end of vector creation

        letrec {
          a2_s7Q6 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7Q6 =
            \ (x_a5HT :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7Q5 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7Q5 =
                  \ (x1_X5J9 :: Int#) (eta1_XP :: State# RealWorld) ->
                    letrec {
                      a4_s7MZ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                      a4_s7MZ =
                        \ (x2_X5Jl :: Int#) (s1_X4Xb :: State# RealWorld) ->
                          case writeIntArray#
                                 (ipv1_a5g4 `cast` ...)
                                 (+# (*# x1_X5J9 100) x2_X5Jl)
                                 1
                                 (s1_X4Xb `cast` ...)
                          of s'#_a5GP { __DEFAULT ->

                          -- the interesting part! ------------------
                          case x2_X5Jl of wild_X1y {
                            __DEFAULT -> a4_s7MZ (+# wild_X1y 1) (s'#_a5GP `cast` ...);
                            99 ->
                              case x1_X5J9 of wild1_X1o {
                                __DEFAULT -> a3_s7Q5 (+# wild1_X1o 1) (s'#_a5GP `cast` ...);
                                99 ->
                                  case x_a5HT of wild2_X1c {
                                    __DEFAULT -> a2_s7Q6 (+# wild2_X1c 1) (s'#_a5GP `cast` ...);
                                    99999 -> (# s'#_a5GP `cast` ..., () #)
                                  }
                              }
                          }
                          }; } in
                    a4_s7MZ 0 eta1_XP; } in
              a3_s7Q5 0 eta_B1; } in
        a2_s7Q6 0 (ipv6_a4MX `cast` ...)
        }
        }
        };
      True ->
        case error
               (unpackAppendCString#
                  "Primitive.basicUnsafeNew: length to large: "#
                  (case $wshowSignedInt 0 10000 ([])
                   of _ { (# ww5_a5wO, ww6_a5wP #) ->
                   : ww5_a5wO ww6_a5wP
                   }))
        of wild_00 {
        }
    }
    }

main :: IO ()
main = main1 `cast` ...

main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)

main :: IO ()
main = main2 `cast` ...
import Control.Monad
import Debug.Trace

expensive :: String -> String
expensive x = trace "$$$" x

main :: IO ()
main = do
  str <- fmap expensive getLine
  replicateM_ 3 $ print str
main :: IO ()
main = do
  !str <- fmap expensive getLine
  replicateM_ 3 $ print str