Note: This answer is available as a literate Haskell file at Gist.
I quite enjoyed this exercise. I tried to do it without looking at the answers, and it was worth it. It took me considerable time, but the result is surprisingly close to two of the other answers, as well as to the monad-coroutine library. So I guess this is a fairly natural solution to this problem. Without this exercise, I wouldn't have understood how monad-coroutine really works.
To add some extra value, I'll explain the steps that eventually led me to the solution.
Recognizing the state monad
Since we're dealing with states, we look for patterns that can be described effectively by the state monad. In particular, s -> s is isomorphic to s -> (s, ()), so it could be replaced by State s (). And a function of type s -> x -> (s, y) can be flipped to x -> (s -> (s, y)), which is essentially x -> State s y. This leads us to the updated signatures
mutate :: State s () -> Pause s ()
step :: Pause s () -> State s (Maybe (Pause s ()))
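A small sketch of the two correspondences above, using State from mtl's Control.Monad.State. The names fromEndo, toEndo, fromStep and toStep are hypothetical, chosen for illustration. Note that mtl's State pairs the result first, (a, s), the reverse of the (s, y) tuples used in the text, so the conversions swap the components. Load it into GHCi to try it.

```haskell
import Control.Monad.State (State, execState, modify, runState, state)

-- s -> s carries the same information as State s ():
-- the action only transforms the state and returns ().
fromEndo :: (s -> s) -> State s ()
fromEndo = modify

toEndo :: State s () -> (s -> s)
toEndo = execState

-- s -> x -> (s, y), with its arguments flipped, becomes x -> State s y.
-- mtl's State uses (result, state) order, so the pair is swapped here.
fromStep :: (s -> x -> (s, y)) -> x -> State s y
fromStep f x = state $ \s -> let (s', y) = f s x in (y, s')

toStep :: (x -> State s y) -> s -> x -> (s, y)
toStep g s x = let (y, s') = runState (g x) s in (s', y)
```

Round-tripping gives back the original function, e.g. toStep (fromStep f) s x equals f s x, which is what makes replacing one form by the other safe.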
Generalization
Our Pause monad is currently parametrized by the state. However, we now see that we don't really need the state for anything, nor do we use anything specific to the state monad. So we could try to make a more general solution that is parametrized by any monad:
mutate :: (Monad m) => m () -> Pause m ()
yield :: (Monad m) => Pause m ()
step :: (Monad m) => Pause m () -> m (Maybe (Pause m ()))
We could also make mutate and step more general by allowing any kind of value, not just (). And by realizing that Maybe a is isomorphic to Either a (), we can finally generalize our signatures to
mutate :: (Monad m) => m a -> Pause m a
yield :: (Monad m) => Pause m ()
step :: (Monad m) => Pause m a -> m (Either (Pause m a) a)
so that step returns either the rest of a suspended computation or, once the computation finishes, its final value.
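The isomorphism used above can be written down directly. In the old result type Maybe (Pause m ()), Just meant "suspended, here is the rest" and Nothing meant "finished", so Just maps to Left and Nothing to Right (); generalizing that () to a gives Either (Pause m a) a. The two function names are hypothetical.

```haskell
-- Hypothetical helper names: the two directions of the isomorphism
-- between Maybe a and Either a ().
maybeToEither :: Maybe a -> Either a ()
maybeToEither (Just x) = Left x      -- a value present: "suspended with x"
maybeToEither Nothing  = Right ()    -- no value: "finished"

eitherToMaybe :: Either a () -> Maybe a
eitherToMaybe (Left x)   = Just x
eitherToMaybe (Right ()) = Nothing
```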
Monad transformer
Now we see that we're actually building a monad from another monad, adding some extra functionality. This is what is usually called a monad transformer. Moreover, mutate's signature is exactly the same as that of lift from MonadTrans. Most likely, we're on the right track.
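For comparison, here is lift at work in an existing transformer, mtl's StateT, over Maybe. The function addFrom is a hypothetical example; the point is only that lift has the shape m a -> t m a, which with t = Pause is exactly mutate's generalized signature.

```haskell
import Control.Monad.State (StateT, get, put, runStateT)
import Control.Monad.Trans.Class (lift)

-- lift :: (MonadTrans t, Monad m) => m a -> t m a
-- Here t = StateT Int and m = Maybe: lift embeds an action of the
-- inner monad into the transformed one.
addFrom :: Maybe Int -> StateT Int Maybe Int
addFrom inner = do
  n <- get
  x <- lift inner   -- run the inner Maybe action; Nothing aborts everything
  put (n + x)
  return n
```

In GHCi, runStateT (addFrom (Just 10)) 5 gives Just (5,15), while runStateT (addFrom Nothing) 5 gives Nothing.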
The final monad
The step function seems to be the most important part of our monad; it defines exactly what we need. Perhaps it could be the new data structure itself? Let's try:
import Control.Monad
import Control.Monad.State
import Control.Monad.Trans

data Pause m a
  = Pause { step :: m (Either (Pause m a) a) }
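To see the structure before any instances exist, one can build a Pause value by hand and drive it with step. This sketch uses Identity as the inner monad; pausedOnce and runCounting are hypothetical names, not part of the answer's file.

```haskell
import Data.Functor.Identity (Identity (..))

data Pause m a
  = Pause { step :: m (Either (Pause m a) a) }

-- A hand-built value (hypothetical example): suspends once, then finishes with 42.
pausedOnce :: Pause Identity Int
pausedOnce = Pause (Identity (Left (Pause (Identity (Right 42)))))

-- Hypothetical driver: keep calling step until we get Right,
-- counting how many times the computation was suspended.
runCounting :: Monad m => Pause m a -> m (Int, a)
runCounting = go 0
  where
    go n p = do
      r <- step p
      case r of
        Left next -> go (n + 1) next
        Right x   -> return (n, x)
```

Here runIdentity (runCounting pausedOnce) evaluates to (1,42).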
If the Either part is Right, it's just a monadic value without any suspensions. This tells us how to implement the easiest part, the lift function from MonadTrans:
instance MonadTrans Pause where
  lift k = Pause (liftM Right k)
and mutate is simply a specialization:
mutate :: (Monad m) => m () -> Pause m ()
mutate = lift
If the Either part is Left, it represents the continued computation after a suspension. So let's create a function for that:
suspend :: (Monad m) => Pause m a -> Pause m a
suspend = Pause . return . Left
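A quick check of how the two cases behave under step, with State Int as the inner monad. The lift definition is repeated as a plain function, hypothetically named liftPause, so this snippet stands alone without the Monad instance; describe, liftedStep and suspendedStep are also hypothetical.

```haskell
import Control.Monad (liftM)
import Control.Monad.State (modify, runState)

data Pause m a
  = Pause { step :: m (Either (Pause m a) a) }

-- Same body as lift in the MonadTrans instance, as a standalone function.
liftPause :: (Monad m) => m a -> Pause m a
liftPause k = Pause (liftM Right k)

suspend :: (Monad m) => Pause m a -> Pause m a
suspend = Pause . return . Left

-- Report which way a single step went.
describe :: Either (Pause m a) a -> String
describe (Left _)  = "suspended"
describe (Right _) = "finished"

-- A lifted action runs its effect and finishes in one step.
liftedStep :: (String, Int)
liftedStep = (describe r, s)
  where (r, s) = runState (step (liftPause (modify (+ 1)))) 0

-- A suspended computation leaves the state alone and hands back the rest.
suspendedStep :: (String, Int)
suspendedStep = (describe r, s)
  where (r, s) = runState (step (suspend (liftPause (modify (+ 1))))) 0
```

So liftedStep is ("finished",1) and suspendedStep is ("suspended",0): the increment inside the suspended computation has not run yet.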
Now yielding is simple: we just suspend with an empty computation:
yield :: (Monad m) => Pause m ()
yield = suspend (return ())
Still, we're missing the most important part: the Monad instance. Let's fix that. Implementing return is simple: we just lift the inner monad. Implementing >>= is a bit trickier. If the original Pause value was just a plain value (Right y), then we simply continue with the first step of f y. If it is a paused computation that can be continued (Left p), we stay suspended and recursively bind f inside the continuation p.
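instance (Monad m) => Functor (Pause m) where
  fmap = liftM  -- GHC >= 7.10 requires Functor/Applicative for every Monad

instance (Monad m) => Applicative (Pause m) where
  pure x = lift (return x) -- i.e. return; Pause (return (Right x))
  (<*>) = ap

instance (Monad m) => Monad (Pause m) where
  (Pause s) >>= f
    = Pause $ s >>= \x -> case x of
        Right y -> step (f y)
        Left p  -> return (Left (p >>= f))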
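Putting the pieces together, here is a self-contained sketch that should compile on current GHC, which (since 7.10) also requires Functor and Applicative instances for every Monad; these are filled in with liftM and ap. The driver countPauses and the value example are hypothetical, added only to check that >>= keeps the suspensions from both of its sides.

```haskell
import Control.Monad (ap, liftM)
import Control.Monad.Trans.Class (MonadTrans (..))
import Data.Functor.Identity (Identity (..))

data Pause m a
  = Pause { step :: m (Either (Pause m a) a) }

instance (Monad m) => Functor (Pause m) where
  fmap = liftM

instance (Monad m) => Applicative (Pause m) where
  pure x = Pause (return (Right x))
  (<*>) = ap

instance (Monad m) => Monad (Pause m) where
  Pause s >>= f = Pause $ s >>= \x -> case x of
    Right y -> step (f y)               -- finished: continue with f y
    Left p  -> return (Left (p >>= f))  -- suspended: bind f inside the rest

instance MonadTrans Pause where
  lift k = Pause (liftM Right k)

suspend :: (Monad m) => Pause m a -> Pause m a
suspend = Pause . return . Left

yield :: (Monad m) => Pause m ()
yield = suspend (return ())

-- Hypothetical driver: step to completion, counting suspensions.
countPauses :: (Monad m) => Pause m a -> m (Int, a)
countPauses = go 0
  where
    go n p = step p >>= either (go (n + 1)) (\x -> return (n, x))

-- Two yields before >>= and one after it: three suspensions in total.
example :: Pause Identity Int
example = (yield >> yield >> return 20) >>= \x -> yield >> return (x + 1)
```

Evaluating runIdentity (countPauses example) gives (3,21).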
Testing
Let's write a model function that reads and updates the state, yielding in the middle of the computation:
test1 :: Int -> Pause (State Int) Int
test1 y = do
  x <- lift get
  lift $ put (x * 2)
  yield
  return (y + x)
And a helper function that debugs the monad by printing the state at each suspension to the console:
debug :: Show s => s -> Pause (State s) a -> IO (s, a)
debug s p = case runState (step p) s of
  (Left next, s') -> print s' >> debug s' next
  (Right r, s')   -> return (s', r)

main :: IO ()
main = do
  debug 1000 (test1 1 >>= test1 >>= test1) >>= print
The result is
2000
4000
8000
(8000,7001)
as expected.
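If printing is inconvenient (say, in a test), a pure variant of debug can collect the states instead. The sketch below repeats the definitions from this answer so that it stands alone; collect is a hypothetical name, not part of the original file.

```haskell
import Control.Monad (ap, liftM)
import Control.Monad.State (State, get, put, runState)
import Control.Monad.Trans.Class (MonadTrans (..))

data Pause m a
  = Pause { step :: m (Either (Pause m a) a) }

instance (Monad m) => Functor (Pause m) where
  fmap = liftM

instance (Monad m) => Applicative (Pause m) where
  pure x = Pause (return (Right x))
  (<*>) = ap

instance (Monad m) => Monad (Pause m) where
  Pause s >>= f = Pause $ s >>= \x -> case x of
    Right y -> step (f y)
    Left p  -> return (Left (p >>= f))

instance MonadTrans Pause where
  lift k = Pause (liftM Right k)

-- Same as suspend (return ()).
yield :: (Monad m) => Pause m ()
yield = Pause (return (Left (return ())))

test1 :: Int -> Pause (State Int) Int
test1 y = do
  x <- lift get
  lift $ put (x * 2)
  yield
  return (y + x)

-- Pure counterpart of debug: instead of printing the state at each
-- suspension, collect those states in a list next to the final result.
collect :: s -> Pause (State s) a -> ([s], (s, a))
collect s p = case runState (step p) s of
  (Left next, s') -> let (states, final) = collect s' next in (s' : states, final)
  (Right r, s')   -> ([], (s', r))
```

Evaluating collect 1000 (test1 1 >>= test1 >>= test1) gives ([2000,4000,8000],(8000,7001)), the same states and result that debug prints.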
Coroutines and monad-coroutine
What we have implemented is a fairly general monadic implementation of coroutines. Perhaps not surprisingly, someone had the idea before :-) and created the monad-coroutine package. Even less surprisingly, it's quite similar to what we created.
The package generalizes the idea even further: the continuing computation is stored inside an arbitrary functor. This allows many variations in how suspended computations are handled, for example passing a value to the caller of resume (which we called step), or waiting for a value to be provided before continuing.
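As an illustration of what the coroutine view allows, here is a sketch of a round-robin scheduler built only on the Pause type from this answer. It does not use monad-coroutine's API. roundRobin and worker are hypothetical names, and the shared log in State [String] is just a convenient way to observe the interleaving.

```haskell
import Control.Monad (ap, forM_, liftM)
import Control.Monad.State (State, execState, modify)
import Control.Monad.Trans.Class (MonadTrans (..))

data Pause m a
  = Pause { step :: m (Either (Pause m a) a) }

instance (Monad m) => Functor (Pause m) where
  fmap = liftM

instance (Monad m) => Applicative (Pause m) where
  pure x = Pause (return (Right x))
  (<*>) = ap

instance (Monad m) => Monad (Pause m) where
  Pause s >>= f = Pause $ s >>= \x -> case x of
    Right y -> step (f y)
    Left p  -> return (Left (p >>= f))

instance MonadTrans Pause where
  lift k = Pause (liftM Right k)

yield :: (Monad m) => Pause m ()
yield = Pause (return (Left (return ())))

-- Run each computation until it yields, then move it to the back of the
-- queue; computations that finish are dropped.
roundRobin :: (Monad m) => [Pause m ()] -> m ()
roundRobin [] = return ()
roundRobin (p : ps) = do
  r <- step p
  case r of
    Left next -> roundRobin (ps ++ [next])
    Right ()  -> roundRobin ps

-- Appends n entries to a shared log, yielding after each one.
worker :: String -> Int -> Pause (State [String]) ()
worker name n = forM_ [1 .. n] $ \i -> do
  lift (modify (++ [name ++ show i]))
  yield
```

Running execState (roundRobin [worker "a" 2, worker "b" 3]) [] gives ["a1","b1","a2","b2","b3"]: the two workers alternate until "a" runs out, after which "b" continues alone.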