いろいろがんばりたいブログ

情報科学科の人がいろいろ書きます。

StateMonadをOCamlで書いた。

というお話。コンパイラを書く時に必要になるっぽいので。

(** Minimal monad interface: a type constructor together with [return] and
    [bind]. *)
module type MONAD = sig
  (** A monadic value carrying a result of type ['a].  The [+] annotation
      declares the type parameter covariant. *)
  type +'a t

  (** [return x] lifts the plain value [x] into a monadic value. *)
  val return : 'a -> 'a t

  (** [bind m f] applies [f], a function from a plain value to a monadic
      value, to the monadic value [m]. *)
  val bind : 'a t -> ('a -> 'b t) -> 'b t
end

(** Monad interface extended with an infix bind operator and a function
    lifter. *)
module type EMONAD = sig
  include MONAD

  (** Infix form of [bind]; the intended definition is
      [let ( >>= ) = bind]. *)
  val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t

  (** [lift f m] lifts the plain function [f] so that it acts on monadic
      values; the intended definition is
      [bind m (fun x -> return (f x))]. *)
  val lift : ('a -> 'b) -> 'a t -> 'b t
end

(** Argument signature of the [STATE_MONAD] functor: it supplies the type of
    the state. *)
module type STATE = sig
  (** The type of the state. *)
  type t
end

(** Signature of the state monad functor.  Given a state type, the result is
    an [EMONAD] plus operations to run a computation and to read or replace
    the state. *)
module type STATE_MONAD = functor (State : STATE) -> sig
  include EMONAD

  (** [run m init] takes the computation [m] and an initial state [init],
      and returns a pair of a result and a state. *)
  val run : 'a t -> State.t -> 'a * State.t

  (** [put s] is a computation that sets the state to [s]. *)
  val put : State.t -> unit t

  (** [get] is a computation whose result is the state. *)
  val get : State.t t
end

(** State monad implementation.  A computation is represented as a function
    from the incoming state to a pair of a result and the outgoing state. *)
module StateMonad : STATE_MONAD = functor (State : STATE) -> struct
  (* A computation: incoming state -> (result, outgoing state). *)
  type 'a t = State.t -> 'a * State.t

  (* Pair the value with the state, leaving the state unchanged. *)
  let return a s = (a, s)

  (* Run [m] on the incoming state, then run [f] applied to the result of
     [m] on the state that [m] produced. *)
  let bind m f s =
    let x, s' = m s in
    f x s'

  (* Infix alias of [bind]. *)
  let ( >>= ) m f = bind m f

  (* Apply the plain function [f] to the result of [m]. *)
  let lift f m = m >>= fun x -> return (f x)

  (* A computation is already a function of the state, so running it is
     plain application to the initial state. *)
  let run m init = m init

  (* Discard the incoming state and install [s]; the result is unit. *)
  let put s _ = ((), s)

  (* Expose the state as the result without modifying it. *)
  let get s = (s, s)
end

(* Sample usage: a state monad whose state is an [int]. *)
module IntStateMonad = StateMonad (struct
  type t = int
end)

(** Recursive list type: a value is either [Nil], which carries nothing, or
    [Cons] of an element and another list. *)
type 'a cons_list =
  | Nil
  | Cons of 'a * 'a cons_list

(** [cons a lst] is an [IntStateMonad] computation that reads the integer
    state, stores its successor as the new state, and returns
    [Cons (a, lst)]. *)
let cons a lst =
  let open IntStateMonad in
  get >>= fun counter ->
  put (succ counter) >>= fun () ->
  return (Cons (a, lst))
(* cons0 below is equivalent to cons, written with explicit bind calls. *)
(* let cons0 a lst = *)
(*   IntStateMonad.(bind (bind get (fun x -> put @@ succ x)) (fun x -> return (Cons (a,lst)))) *)


(* (\* リストに一個追加したら、カウンターが1になる *\) *)
(* let () = assert ((Cons ("a", Nil), 1) = *)
(*     (IntStateMonad.(run (cons "a" Nil >>= fun s -> return s) 0))) *)

(* (\* リストに2個追加したら、カウンターが2になる *\) *)
(* let () = assert ((Cons ("b", Cons ("a", Nil)), 2) = *)
(*     (IntStateMonad.(run (cons "a" Nil >>= cons "b" >>= fun s -> return s) 0))) *)

参考