-- | Implements Tarjan's algorithm for computing the strongly connected
-- components of a graph.  For more details see:
-- <http://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm>
{-# LANGUAGE Rank2Types, Trustworthy #-}
module Data.Graph.ArraySCC(scc) where

import Data.Graph(Graph,Vertex)
import Data.Array.ST(STUArray, newArray, readArray, writeArray)
import Data.Array as A
import Data.Array.Unsafe(unsafeFreeze)
import Control.Monad.ST
import Control.Monad(ap)

-- | Computes the strongly connected components (SCCs) of the graph in
-- O(#edges + #vertices) time.  The resulting tuple contains:
--
--   * The list of SCCs, each paired with a unique identifier of type
--     'Int'.
--
--       * Identifiers are assigned from 1 upwards, in the order in which
--         Tarjan's algorithm finishes the components.
--
--       * The list is ordered by decreasing identifier.  So if a vertex of
--         component @c1@ has an edge to a vertex of a different component
--         @c2@, then @c1@ appears before @c2@ in the list.  This is the
--         reverse of the dependencies-first order in which the algorithm
--         finishes them.
--
--       * The first vertex of each component is the vertex through which
--         the traversal entered that component.
--
--   * An O(1) mapping from vertices in the original graph to the identifier
--     of their SCC.  This mapping will raise an \"out of bounds\"
--     exception if it is applied to integers that do not correspond to
--     vertices in the input graph.
--
-- This function assumes that the adjacency lists in the original graph
-- mention only nodes that are in the graph.  Violating this assumption
-- will result in an \"out of bounds\" array exception.
scc :: Graph -> ([(Int,[Vertex])], Vertex -> Int)
scc g = runST (
  do -- An index entry of 0 means "not visited yet".
     ixes <- newArray (bounds g) 0
     lows <- newArray (bounds g) 0
     -- DFS numbers and SCC identifiers both start at 1, so an index entry
     -- of 0 can never be confused with "on the stack" (negative) or
     -- "finished" (positive).
     s <- roots g ixes lows (S [] 1 [] 1) (indices g)
     -- Once every vertex has been used as a potential root, each vertex
     -- has been popped into an SCC.  The index array therefore holds SCC
     -- identifiers, and we expose it read-only as the lookup table.
     sccm <- unsafeFreeze ixes
     return (sccs s, (sccm !))
  )

-- | The common signature of the traversal helpers.  Every helper takes,
-- in this order:
--
--   * the graph being analysed;
--
--   * the index array, which holds for each vertex one of:
--
--       * @0@: the vertex has not been visited yet;
--
--       * a negative number @-n@: the vertex is on the stack and has DFS
--         number @n@;
--
--       * a positive number @n@: the vertex belongs to the finished SCC
--         with identifier @n@;
--
--   * the low-link array, which holds for each vertex the least DFS number
--     known to be reachable from it;
--
--   * the current traversal state;
--
-- and yields a result of type @a@.
type Func s a
   = Graph
  -> STUArray s Vertex Int
  -> STUArray s Vertex Int
  -> S
  -> a

-- | State threaded through the depth-first traversal.
data S = S { stack    :: ![Vertex]
             -- ^ Traversal stack; the most recently pushed vertex is first
           , num      :: !Int
             -- ^ DFS number to give to the next visited vertex
           , sccs     :: ![(Int,[Vertex])]
             -- ^ Finished SCCs paired with their identifiers; the most
             -- recently finished SCC is first
           , next_scc :: !Int
             -- ^ Identifier to give to the next finished SCC
           }


-- | Walks the given list of vertices and starts a fresh depth-first
-- traversal (via 'from_root') from each one that has not been visited yet.
-- A vertex counts as visited when its index entry is non-zero.  The
-- traversal state is threaded through every call, and the final state is
-- returned.
roots :: Func s ([Vertex] -> ST s S)
roots _ _ _ st [] = return st
roots g ixes lows st (v:vs) =
  do i <- readArray ixes v
     -- Vertices reached by an earlier traversal are skipped.
     st' <- if i == 0 then from_root g ixes lows st v else return st
     roots g ixes lows st' vs


-- | Visits vertex @v@, which should not have been visited yet.
--
-- The vertex is given the next DFS number, pushed on the stack, and its
-- successors are explored with 'check_adj'.  If afterwards no vertex with a
-- smaller DFS number is reachable from @v@, then @v@ is the root of an SCC.
-- In that case @v@ and everything pushed above it are popped off the stack,
-- marked with a fresh SCC identifier in the index array, and prepended to
-- the list of finished SCCs.
--
-- Calls 'error' if @v@ is unexpectedly missing from the stack, which
-- indicates an internal bug.
from_root :: Func s (Vertex -> ST s S)
from_root g ixes lows s v =
  do let me = num s
     -- A negative index entry means "on the stack with DFS number me".
     -- Initially v is only known to reach itself.
     writeArray ixes v (negate me)
     writeArray lows v me
     newS <- check_adj g ixes lows
                       s { stack = v : stack s, num = me + 1 } v (g ! v)

     x <- readArray lows v
     -- Something with a smaller DFS number is reachable, so v is not the
     -- root of its SCC.  Leave it on the stack for an ancestor to collect.
     if x < me then return newS else
       case span (/= v) (stack newS) of
         (above, root:below) ->
           -- v (== root) plus every vertex pushed after it forms one SCC.
           do let this = root : above
                  n = next_scc newS
              -- A positive index entry marks the vertex as finished in SCC n.
              -- This also takes it off the stack as far as 'check_adj' is
              -- concerned.
              mapM_ (\i -> writeArray ixes i n) this
              return S { stack = below
                       , num = num newS
                       , sccs = (n, this) : sccs newS
                       , next_scc = n + 1
                       }
         _ -> error ("bug in scc---vertex not on the stack: " ++ show v)

-- | Processes the successors of vertex @v@, one at a time, updating @v@'s
-- low-link:
--
--   * An unvisited successor (index entry 0) is visited recursively with
--     'from_root'.  Afterwards @v@'s low-link becomes the minimum of its own
--     and the successor's.
--
--   * A successor that is still on the stack (negative index entry) lowers
--     @v@'s low-link to at most that successor's DFS number.
--
--   * A successor that already belongs to a finished SCC (positive index
--     entry) is ignored.
--
-- Returns the traversal state after all successors have been processed.
check_adj :: Func s (Vertex -> [Vertex] -> ST s S)
check_adj _ _ _ st _ [] = return st
check_adj g ixes lows st v (v':vs) =
  do i <- readArray ixes v'
     case compare i 0 of
       EQ ->
         do newS <- from_root g ixes lows st v'
            new_low <- min `fmap` readArray lows v `ap` readArray lows v'
            writeArray lows v new_low
            check_adj g ixes lows newS v vs
       LT ->
         -- The entry is the negated DFS number of v'.
         do j <- readArray lows v
            writeArray lows v (min j (negate i))
            check_adj g ixes lows st v vs
       GT -> check_adj g ixes lows st v vs