{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE UndecidableInstances  #-}

module Snap.Snaplet.Internal.RST where

import           Control.Applicative         (Alternative (..),
                                              Applicative (..))
import           Control.Monad
import           Control.Monad.Base          (MonadBase (..))
import qualified Control.Monad.Fail as Fail
import           Control.Monad.Reader        (MonadReader (..))
import           Control.Monad.State.Class   (MonadState (..))
import           Control.Monad.Trans         (MonadIO (..), MonadTrans (..))
import           Control.Monad.Trans.Control (ComposeSt, MonadBaseControl (..),
                                              MonadTransControl (..),
                                              defaultLiftBaseWith,
                                              defaultRestoreM)
import           Snap.Core                   (MonadSnap (..))


------------------------------------------------------------------------------
-- | Reader + state monad transformer.
--
-- This is like @RWST@ with the writer component removed, so no output is
-- accumulated. Bind ('rwsBind') is inlined and forces the state to weak head
-- normal form at every step.
newtype RST r s m a = RST
    { runRST :: r -> s -> m (a, s)
      -- ^ Run with an environment @r@ and an initial state @s@; the inner
      -- action yields the result together with the final state.
    }


-- | Run an 'RST' computation and keep only its result, discarding the final
-- state.
evalRST :: Monad m => RST r s m a -> r -> s -> m a
evalRST m r s = runRST m r s >>= \(result, _) -> return result
{-# INLINE evalRST #-}


-- | Run an 'RST' computation and keep only its final state, discarding the
-- result. The final state is forced to weak head normal form before it is
-- returned.
execRST :: Monad m => RST r s m a -> r -> s -> m s
execRST m r s = runRST m r s >>= \(_, !finalState) -> return finalState
{-# INLINE execRST #-}


-- | Run a computation under an environment derived from the current one.
--
-- @f@ adapts the outer environment @r'@ into the environment @r@ expected by
-- the inner computation. The incoming state is handed to the inner
-- computation, and its result and final state are returned as-is.
--
-- No constraint on @m@ is required: the inner action is only re-wrapped,
-- never sequenced. This matches 'mapRST', which is also unconstrained.
withRST :: (r' -> r) -> RST r s m a -> RST r' s m a
withRST f m = RST $ \r' s -> runRST m (f r') s
{-# INLINE withRST #-}


-- | The environment is the first argument given to 'runRST'. 'ask' returns it
-- without touching the state. 'local' is 'withRST' specialised to a function
-- from the environment type to itself.
instance (Monad m) => MonadReader r (RST r s m) where
    ask = RST (\env st -> return (env, st))
    local = withRST


-- | Map over the result of the computation; the final state is passed
-- through untouched.
instance (Functor m) => Functor (RST r s m) where
    fmap f (RST run) = RST $ \env st ->
        let applyToResult (x, st') = (f x, st')
        in  fmap applyToResult (run env st)


-- | 'pure' is defined directly: it ignores the environment and passes the
-- state through unchanged. The previous backwards form, @pure = return@, is
-- flagged by GHC's @-Wnoncanonical-monad-instances@. The body matches the
-- 'Monad' instance's 'return', so the two still agree.
--
-- '<*>' is 'ap', so the function-producing computation runs first and the
-- state flows left to right.
instance (Functor m, Monad m) => Applicative (RST r s m) where
    pure a = RST $ \_ s -> return (a, s)
    (<*>) = ap


-- | Choice is inherited from the 'MonadPlus' instance for 'RST'.
instance (Functor m, MonadPlus m) => Alternative (RST r s m) where
    empty = mzero
    x <|> y = x `mplus` y


-- | 'get' returns the current state and leaves it unchanged. 'put' discards
-- the incoming state in favour of the supplied one. Neither consults the
-- environment.
instance (Monad m) => MonadState s (RST r s m) where
    get = RST (\_ current -> return (current, current))
    put new = RST (\_ _ -> return ((), new))


-- | Transform the underlying action of an 'RST' computation (its result and
-- final state together). This may change both the inner monad and the
-- result type.
mapRST :: (m (a, s) -> n (b, s)) -> RST r s m a -> RST r s n b
mapRST f m = RST step
  where
    step env st = f (runRST m env st)


-- | Snap actions are lifted into the inner monad first, then into 'RST'.
instance (MonadSnap m) => MonadSnap (RST r s m) where
    liftSnap = lift . liftSnap

-- | Strict bind for 'RST'.
--
-- The incoming state is forced to weak head normal form before the first
-- computation runs. The state it produces is forced the same way before the
-- continuation runs, which keeps unevaluated state (at WHNF depth) from
-- building up across binds. Both computations receive the same environment.
rwsBind :: Monad m =>
           RST r s m a
        -> (a -> RST r s m b)
        -> RST r s m b
rwsBind m f = RST step
  where
    step env !st =
        runRST m env st >>= \(x, !st') -> runRST (f x) env st'
{-# INLINE rwsBind #-}

-- | 'return' ignores the environment and leaves the state unchanged. Bind is
-- the strict 'rwsBind'.
instance (Monad m) => Monad (RST r s m) where
    return x = RST (\_ st -> return (x, st))
    (>>=) = rwsBind
    -- NOTE(review): the conditional below needs CPP. No CPP pragma is visible
    -- at the top of this file, so presumably it is enabled in the cabal file;
    -- confirm.
#if !MIN_VERSION_base(4,13,0)
    fail msg = RST $ \_ _ -> fail msg
#endif

-- | Failure is delegated to the inner monad; the environment and the state
-- are both ignored.
instance Fail.MonadFail m => Fail.MonadFail (RST r s m) where
    fail msg = RST (const (const (Fail.fail msg)))

-- | Both alternatives start from the same environment and the same incoming
-- state. The inner monad's 'mplus' combines their outcomes, including the
-- final states.
instance (MonadPlus m) => MonadPlus (RST r s m) where
    mzero = RST (\_ _ -> mzero)
    mplus left right = RST $ \env st ->
        runRST left env st `mplus` runRST right env st


-- | IO is lifted into the inner monad first, then into 'RST'.
instance (MonadIO m) => MonadIO (RST r s m) where
    liftIO io = lift (liftIO io)


-- | Run the inner action without consulting the environment, keeping the
-- current state. The result pair is wrapped in @seq@, so the state is
-- evaluated (to WHNF) whenever the pair itself is evaluated.
instance MonadTrans (RST r s) where
    lift m = RST $ \_ s ->
        m >>= \a -> return (s `seq` (a, s))


-- | Base-monad actions are lifted into the inner monad first, then into
-- 'RST'.
instance MonadBase b m => MonadBase b (RST r s m) where
    liftBase action = lift (liftBase action)


-- | Base-control support built from the 'MonadTransControl' instance for
-- 'RST' using monad-control's default helpers. The captured state composes
-- the 'RST' state with the inner monad's captured state.
instance MonadBaseControl b m => MonadBaseControl b (RST r s m) where
    type StM (RST r s m) a = ComposeSt (RST r s) m a

    {-# INLINE liftBaseWith #-}
    liftBaseWith = defaultLiftBaseWith

    {-# INLINE restoreM #-}
    restoreM = defaultRestoreM


-- | The captured state of an 'RST' computation is its result paired with
-- its final state.
--
-- 'liftWith' gives the callback a runner closed over the current environment
-- and state, and keeps the current state afterwards. 'restoreT' ignores the
-- environment and current state and takes the result and state from the
-- restored action.
instance MonadTransControl (RST r s) where
    type StT (RST r s) a = (a, s)

    {-# INLINE liftWith #-}
    liftWith f = RST $ \r s ->
        f (\(RST g) -> g r s) >>= \res -> return (res, s)

    {-# INLINE restoreT #-}
    restoreT k = RST (\_ _ -> k)