I have been trying to understand the State Monad. Not so much how it is used (though that is not always easy to find, either), but how it works: every discussion of the State Monad I find has basically the same information, and there is always something I don't understand.
Take this post, for example. In it the author has the following:
```scala
case class State[S, A](run: S => (A, S)) {
  ...
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      f(a) run t
    })
  ...
}
```
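To experiment with the snippet above, here is a minimal self-contained sketch of the same `State`. The `map` method, the `tick` action and the `Demo` object are illustrative additions, not code from the blog post. The comments mark which object each `run` belongs to: the first is `this` action, the second is the different action that `f` builds from `a`.

```scala
// Minimal State, same shape as the blog post's version.
// `map`, `tick` and `Demo` are illustrative additions.
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, t) = run(s) // 1st run: THIS action, starting from state s
      f(a).run(t)         // 2nd run: a DIFFERENT action (built from a), starting from t
    }

  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (a, t) = run(s) // only one action here, so only one run
      (f(a), t)
    }
}

object Demo {
  // One transition: return the current counter value and increment it.
  val tick: State[Int, Int] = State(n => (n, n + 1))

  // Two actions composed: tick, then (with its result in scope) tick again.
  val twoTicks: State[Int, (Int, Int)] =
    tick.flatMap(a => tick.map(b => (a, b)))

  def main(args: Array[String]): Unit =
    println(twoTicks.run(0)) // ((0,1),2)
}
```

Starting from `0`, the first `tick` returns `0` and leaves `1`; the second `tick` starts from `1`, returns `1` and leaves `2`. Two actions, two transitions.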
I can see that the types line up correctly. However, I don't understand the second `run` at all.
Perhaps I am looking at the whole purpose of this monad incorrectly. I got the impression from the HaskellWiki that the State monad is something like a state machine, with `run` performing the transitions (though in this case the state machine doesn't have fixed state transitions like most state machines do). If that is the case, then in the code above `(a, t)` would represent a single transition. The application of `f` would represent a modification of that value and state, generating a new `State` object. That leaves me completely confused about what the second `run` is doing. It would appear to be a second 'transition', but that doesn't make any sense to me.
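The state-machine reading can still be made to fit if each `State` value is treated as one transition function `S => (A, S)`, and `flatMap` as gluing two such transitions into one larger one. A minimal sketch, assuming a hypothetical `step` helper whose state is simply a log of which actions have run, shows that each action's `run` is called exactly once:

```scala
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, t) = run(s) // runs the first action
      f(a).run(t)         // runs the second action, starting where the first stopped
    }
}

object LogDemo {
  // Hypothetical helper: one transition that appends `name` to the log.
  def step(name: String): State[List[String], Unit] =
    State(log => ((), log :+ name))

  val program: State[List[String], Unit] =
    step("first").flatMap(_ => step("second"))

  def main(args: Array[String]): Unit =
    println(program.run(Nil)) // ((),List(first, second))
}
```

Each name appears in the log exactly once, so the two `run` calls in `flatMap` correspond to the two composed actions rather than to one action being stepped twice.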
I can see that calling `run` on the resulting `State` object produces a new `(A, S)` pair, which of course is required for the types to line up. But I don't really see what this is supposed to be doing.
So, what is really going on here? What is the concept being modeled here?
Edit: 12/22/2015
So, it appears I am not expressing my issue very well. Let me try this.
In the same blog post we see the following code for `map`:
```scala
def map[B](f: A => B): State[S, B] =
  State(s => {
    val (a, t) = run(s)
    (f(a), t)
  })
```
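One way to see why `map` needs only one `run`: `map` can be written as a `flatMap` whose second action is a "do-nothing" `State` that returns a value and hands the state back unchanged (conventionally called `unit` or `pure`). The names `unit`, `mapViaFlatMap` and `MapDemo` below are introduced for this sketch and are not from the blog post.

```scala
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, t) = run(s)
      f(a).run(t)
    }

  // map expressed through flatMap: the second action is `unit`,
  // so the second `run` returns the state it was given.
  def mapViaFlatMap[B](f: A => B): State[S, B] =
    flatMap(a => State.unit[S, B](f(a)))
}

object State {
  // Hypothetical helper: produce `a` without changing the state.
  def unit[S, A](a: A): State[S, A] = State(s => (a, s))
}

object MapDemo {
  val tick: State[Int, Int] = State(n => (n, n + 1))

  def main(args: Array[String]): Unit =
    println(tick.mapViaFlatMap(_ * 10).run(5)) // (50,6)
}
```

Here the second `run` in `flatMap` does happen, but it runs an action that makes no state change, so the result is the same as the single-`run` `map` above.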
Obviously there is only a single call to `run` here.
The model I have been trying to reconcile is that a call to `run` moves the state we are keeping forward by a single state change. This seems to be the case in `map`. However, in `flatMap` there are two calls to `run`. If my model were correct, that would result in 'skipping over' a state change.
To make use of the example @Filippo provided below, the first call to `run` would return `(1, List(2,3,4,5))` and the second would return `(2, List(3,4,5))`, effectively skipping over the first one. Since, in his example, this was followed immediately by a call to `map`, this would have resulted in `(Map(a->2, b->3), List(4,5))`.
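A sketch of a comprehension with the same shape may help count the transitions. It is a hypothetical reconstruction, not necessarily @Filippo's exact code, and `head` is a stand-in that pops one element. The point is where `map` sits after desugaring: it is applied to the inner action, inside the function passed to `flatMap`, so it is part of the second `run`, not a third transition.

```scala
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, t) = run(s)
      f(a).run(t)
    }

  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (a, t) = run(s)
      (f(a), t)
    }
}

object ForDemo {
  // Hypothetical stand-in for `head`: pop one element off the list.
  def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

  val pop: State[List[Int], Int] = State[List[Int], Int](head[Int])

  // Written as a for-comprehension...
  val pair: State[List[Int], Map[String, Int]] =
    for { a <- pop; b <- pop } yield Map("a" -> a, "b" -> b)

  // ...which desugars to this: `map` is INSIDE the function given to flatMap.
  val pairDesugared: State[List[Int], Map[String, Int]] =
    pop.flatMap(a => pop.map(b => Map("a" -> a, "b" -> b)))

  def main(args: Array[String]): Unit = {
    println(pair.run(List(1, 2, 3, 4, 5)))          // (Map(a -> 1, b -> 2),List(3, 4, 5))
    println(pairDesugared.run(List(1, 2, 3, 4, 5))) // same result
  }
}
```

Two `pop` actions make two transitions in total; counting `flatMap` as two transitions and `map` as a third would double-count the inner `pop`.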
Apparently that is not what is happening. So my whole model is incorrect. What is the correct way to reason about this?
2nd Edit: 12/22/2015
I just tried doing what I said in the REPL, and my instincts were correct, which leaves me even more confused.
```scala
scala> val v = State(head[Int]).flatMap { a => State(head[Int]) }
v: State[List[Int],Int] = State(<function1>)

scala> v.run(List(1,2,3,4,5))
res2: (Int, List[Int]) = (2,List(3, 4, 5))
```
So this implementation of `flatMap` does skip over a state. Yet when I run @Filippo's example I get the same answer he does. What is really happening here?
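The REPL result is also consistent with "one transition per action" if the discarded value is taken into account. A minimal sketch, again with `head` as a hypothetical stand-in: in the REPL expression the function passed to `flatMap` ignores `a`, so the `1` is popped by the first action and then thrown away. Keeping `a` shows that the same two transitions happen either way.

```scala
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, t) = run(s)
      f(a).run(t)
    }

  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (a, t) = run(s)
      (f(a), t)
    }
}

object ReplDemo {
  // Hypothetical stand-in for `head`: pop one element off the list.
  def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

  val pop: State[List[Int], Int] = State[List[Int], Int](head[Int])

  // Like the REPL example: the continuation ignores the first value (1).
  val ignoresFirst: State[List[Int], Int] = pop.flatMap(_ => pop)

  // Same two transitions, but the continuation keeps the first value.
  val keepsBoth: State[List[Int], (Int, Int)] =
    pop.flatMap(a => pop.map(b => (a, b)))

  def main(args: Array[String]): Unit = {
    val input = List(1, 2, 3, 4, 5)
    println(ignoresFirst.run(input)) // (2,List(3, 4, 5))
    println(keepsBoth.run(input))    // ((1,2),List(3, 4, 5))
  }
}
```

In both cases the remaining state is `List(3, 4, 5)`: two elements consumed by two actions. The only difference is whether the first action's result is kept.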
…State Monad, there is also a video online: youtube.com/watch?v=Jg3Uv_YWJqI – Filippo Vitale