haskell 为什么在并行化K均值时没有得到任何增益?

gjmwrych  于 2023-03-03  发布在  其他
关注(0)|答案(1)|浏览(152)

我正在Haskell中学习Simon马洛的书中的并行编程。我正在学习第3章。我不明白为什么在6核机器上使用-N6标志启动此代码对性能没有任何影响。我得到了similar problem,然后才意识到添加更多数据进行处理会使并行和单线程执行之间的差异更加显著。但在这种情况下,即使添加更多的数据也不会给执行时间带来差异。在我的练习中,我试图找到世界城市的质心。为了给它更多的数据,我创建了一些“假”城市。
以下是城市文件的前五行:
worldcities.csv

"city","city_ascii","lat","lng","country","iso2","iso3","admin_name","capital","population","id"
"Tokyo","Tokyo","35.6839","139.7744","Japan","JP","JPN","Tōkyō","primary","39105000","1392685764"
"Jakarta","Jakarta","-6.2146","106.8451","Indonesia","ID","IDN","Jakarta","primary","35362000","1360771077"
"Delhi","Delhi","28.6667","77.2167","India","IN","IND","Delhi","admin","31870000","1356872604"
"Manila","Manila","14.6000","120.9833","Philippines","PH","PHL","Manila","primary","23971000","1608618140"

下面是我用来添加更多“城市”的代码:

main :: IO ()
main = do
    [c] <- getArgs
    cities <- getCities c
    let target = "wc_extended.csv"
    BS.writeFile target (encode [("city" :: BS.ByteString, "lat" :: BS.ByteString, "lng" :: BS.ByteString)])
    forM_ cities $ \(City name (Point lat lng)) -> BS.appendFile target $ encode [
            (name,lat,lng)
        ,   ("A" <> name,lat-10,lng+10)
        ,   ("B" <> name,lat-20,lng+20)
        ,   ("C" <> name,lat-30,lng+30)
        ,   ("D" <> name,lat-40,lng+40)
        ,   ("E" <> name,lat-50,lng+50)
        ]

这是我的K-均值聚类程序的源文件。
Types.hs

module Types where

import Data.ByteString (ByteString)
import System.Random
import System.Random.Stateful

data Point = Point {
        lat :: !Double
    ,   lng :: !Double
} deriving (Eq,Show)

instance Uniform Point where
    uniformM g = do
                    lat <- uniformRM (-180, 180) g
                    lng <- uniformRM (-180, 180) g
                    return $ Point lat lng

instance Semigroup Point where
    (Point lat lng) <> (Point lat' lng') = Point (lat + lat') (lng + lng')

instance Monoid Point where
    mempty = Point 0 0

sqDistance :: Point -> Point -> Double
sqDistance (Point lat lng) (Point lat' lng') = (lat-lat')^2 + (lng-lng')^2 

data City = City {
        name :: ByteString
    ,   location :: Point
} deriving Show

data Cluster = Cluster {
        cId :: Int
    ,   center :: Point
} deriving (Eq, Show)

data PointSum = PointSum !Int !Point

instance Semigroup PointSum where
    (PointSum c p) <> (PointSum c' p') = PointSum (c+c') (p <> p')

instance Monoid PointSum where
    mempty = PointSum 0 mempty

addToPointSum :: Point -> PointSum -> PointSum
addToPointSum point' (PointSum count point) = PointSum (count+1) $ point <> point'

pointSumToCluster :: Int -> PointSum -> Cluster
pointSumToCluster i (PointSum count (Point lat lng)) = Cluster {
      cId = i
    , center = Point (lat / fromIntegral count) (lng / fromIntegral count)
}

CitiesLoader.hs

{-# LANGUAGE OverloadedStrings #-}
module CitiesLoader where

import Data.Attoparsec.ByteString
import Data.Csv
import Data.Vector (Vector)
import qualified Data.Vector as V 
import qualified Data.ByteString as BS
import qualified Data.ByteString.UTF8 as UTF8
import Data.Csv.Parser (csvWithHeader)
import Data.HashMap.Strict ( (!) )

import Types

getCSV :: FilePath -> IO (Vector NamedRecord)
getCSV path = do
    raw <- BS.readFile path
    case parseOnly (csvWithHeader defaultDecodeOptions) raw of
        Left error -> do
            putStrLn $ "Error during parsing: " <> error <> ", returned empty result"
            return mempty
        Right (_, values) -> return values

extractCities :: Vector NamedRecord -> Vector City
extractCities = fmap f
                where f vmap = City (vmap ! "city") $ Point ((read . UTF8.toString) $ vmap ! "lat") ((read . UTF8.toString) $ vmap ! "lng")

getCities :: FilePath -> IO (Vector City)
getCities = (fmap . fmap) extractCities getCSV

Clustering.hs

module Clustering where

import Types
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as M
import Data.Function (on)
import Data.List (minimumBy)
import Control.Monad.Trans.Except
import Control.Parallel.Strategies

assign :: Int -> [Cluster] -> Vector City -> Vector PointSum
assign n clusters points = V.create $ do
    vec <- M.replicate n mempty
    let addpoint (City _ p) = M.modify vec (addToPointSum p) (cId $ nearest p)
    V.mapM_ addpoint points
    return vec
    where nearest p = fst $ minimumBy (compare `on` snd) [(c, sqDistance p (center c)) | c <- clusters]

makeNewClusters :: Vector PointSum -> [Cluster]
makeNewClusters vec = [pointSumToCluster i ps | (i,ps@(PointSum count _)) <- zip [0..] (V.toList vec), count > 0]

step :: Int -> Vector City -> [Cluster] -> [Cluster]
step n cities clusters = makeNewClusters $ assign n clusters cities

kmeansSeq :: Int -> Vector City -> [Cluster] -> Except String [Cluster]
kmeansSeq limit cities clusters = loop 0 clusters
                            where loop n c | n > limit = throwE "reached loop limit"
                                  loop n c = let c' = step nClusters cities c
                                                     in if c' == c
                                                            then return c'
                                                            else loop (n+1) c'
                                  nClusters = length clusters

split :: Int -> Vector a -> [Vector a]
split numChunks xs = chunk (V.length xs `quot` numChunks) xs

chunk :: Int -> Vector a -> [Vector a]
chunk n xs | V.null xs = []
chunk n xs = as : chunk n bs
    where (as, bs) = V.splitAt n xs

combine :: Vector PointSum -> Vector PointSum -> Vector PointSum
combine = V.zipWith (<>)

parStepsStrat :: Int -> [Vector City] -> [Cluster] -> [Cluster]
parStepsStrat n pointss clusters = makeNewClusters $ foldr1 combine (map (assign n clusters) pointss `using` parList rseq)

kMeansStrat :: Int -> Int -> Vector City -> [Cluster] -> Except String [Cluster]
kMeansStrat limit numChunks points clusters = loop 0 clusters
                                        where loop n clusters | n > limit = throwE "reached loop limit"
                                              loop n clusters = let c' = parStepsStrat nClusters chunks clusters
                                                                         in if c' == clusters
                                                                            then return c'
                                                                            else loop (n+1) c'
                                              chunks = split numChunks points
                                              nClusters = length clusters

Main.hs

{-# LANGUAGE OverloadedStrings#-}

module Main where
import System.Environment (getArgs)
import CitiesLoader (getCities)
import System.Random
import System.Random.Stateful (newIOGenM, uniformListM)
import Types
import Clustering
import Control.Monad.Trans.Except (runExcept)
import Data.Vector(forM_)
import Data.Csv(encode)
import qualified Data.ByteString.Lazy as BS

main :: IO ()
main = do
    [c] <- getArgs
    cities <- getCities c
    print (length cities)
    let seed = mkStdGen $ length cities
    g <- newIOGenM seed
    centroids <- uniformListM 1000 g
    let clusters = zipWith Cluster [0..] centroids
    case runExcept (kMeansStrat 10000 6 cities clusters) of
        Left err -> putStrLn err
        Right c -> print c

当我在包含257431条记录的文件上运行此命令时,我得到了以下执行时间:
阴谋执行-- wc_extended.csv +RTS -N1 -s -l
总时间为266.631秒(经过266.351秒)
和螺纹镜轮廓

电缆执行-- wc_extended.csv +RTS -N6 -s -l
总时间1737.342s(已用340.016s)(执行时间甚至增加)
和螺纹镜轮廓

sr4lhrrt

sr4lhrrt1#

尝试使用parList rdeepseq而不是parList rseq,当然您需要PointSumNFData示例,但我相信您可以弄清楚如何实现。
问题是您只将计算并行化到WHNF,这意味着没有任何工作是并行完成的。
如果信息不够,请随意发表评论,我可以将其扩展为更详细的答案

相关问题