
I'm trying to understand how to memoize functions in Haskell over arguments of various datatypes. I have implemented the tabulate and apply functions for the Tree type as found in Ralf Hinze's article "Memo functions, polytypically!"

My implementation is below. My test function counts the number of subtrees in a tree of depth d. Shouldn't this function be faster when it is memoized? It is not; timing both versions on my system gives:

helmholtz:LearningHaskell edechter$ time ./Memo 1 23
Not memoized: # of subtrees for tree of depth 23 is: 25165822

real    0m1.898s
user    0m1.886s
sys     0m0.011s
helmholtz:LearningHaskell edechter$ time ./Memo 0 23
Memoized: # of subtrees for tree of depth 23 is: 25165822

real    0m5.129s
user    0m5.013s
sys     0m0.115s

My code is simple:

-- Memo.hs
import System.Environment

data Tree = Leaf | Fork Tree Tree deriving Show
data TTree v = NTree v (TTree (TTree v)) deriving Show

applyTree :: TTree v -> (Tree -> v)
applyTree (NTree tl tf) Leaf = tl
applyTree (NTree tl tf) (Fork l r) = applyTree (applyTree tf l) r

tabulateTree :: (Tree -> v) -> TTree v
tabulateTree f = NTree (f Leaf) (tabulateTree $ \l
                                     -> tabulateTree $ \r -> f (Fork l r))

numSubTrees :: Tree -> Int
numSubTrees Leaf = 1
numSubTrees (Fork l r ) = 2 + numSubTrees l + numSubTrees r

memo = applyTree . tabulateTree

mkTree d | d == 0 = Leaf
         | otherwise = Fork (mkTree $ d-1) (mkTree $ d-1)

main = do
  args <- getArgs
  let version = read $ head args
      d = read $ args !! 1
      (version_name, out) = if version == 0
                              then ("Memoized", (memo numSubTrees) (mkTree d))
                              else ("Not memoized", numSubTrees (mkTree d))
  putStrLn $ version_name ++ ": # of subtrees for tree of depth "
               ++ show d ++ " is: " ++ show out

UPDATE

I see why my function would not take advantage of memoization, but I still don't understand how to build a function that does. Based on the Fibonacci memoization example here, my attempt looks like:

memofunc :: Tree -> Int
memofunc  = memo f
    where f (Fork l r) = memofunc l + memofunc r
          f (Leaf) = 1

func :: Tree -> Int
func (Leaf) = 1
func (Fork l r) = func l + func r

But this still does not do the right thing:

helmholtz:LearningHaskell edechter$ time ./Memo 0 23
Memoized: # of subtrees for tree of depth 23 is: 8388608

real    0m10.436s
user    0m9.895s
sys     0m0.532s
helmholtz:LearningHaskell edechter$ time ./Memo 1 23
Not memoized: # of subtrees for tree of depth 23 is: 8388608

real    0m1.666s
user    0m1.654s
sys     0m0.011s
numSubTrees only makes one pass over the tree, so you can't expect memoization to make any difference there. Did you expect it to be able to notice that the trees generated by mkTree have equal left and right subtrees? (It can't.) – hammar
Okay, but why can't the memoized version notice that they have equal left and right subtrees? After computing the left subtree, won't it look up the right subtree in the table rather than recomputing it? – Eyal

2 Answers

4
votes

numSubTrees is a recursive function, and memo can't see inside the recursion. As a result, memo numSubTrees only does a table lookup for the outermost call, while the recursive calls still go to the unmemoized numSubTrees.
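One way to route the recursive calls through the table is to write the function against an explicit "self" argument and tie the knot with the memoized version. The following is a minimal sketch under that assumption; memoFix, countNodesStep and countNodes are hypothetical names that do not appear in the question, and the trie code is repeated only so the block stands alone.

```haskell
-- Sketch only: memoFix, countNodesStep and countNodes are illustrative names.
data Tree = Leaf | Fork Tree Tree

-- Memo trie over Tree, as in the question.
data TTree v = NTree v (TTree (TTree v))

applyTree :: TTree v -> Tree -> v
applyTree (NTree tl _)  Leaf       = tl
applyTree (NTree _  tf) (Fork l r) = applyTree (applyTree tf l) r

tabulateTree :: (Tree -> v) -> TTree v
tabulateTree f =
  NTree (f Leaf) (tabulateTree $ \l -> tabulateTree $ \r -> f (Fork l r))

memo :: (Tree -> v) -> Tree -> v
memo = applyTree . tabulateTree

-- Tie the knot: the step function receives the memoized version of itself,
-- so every recursive call goes through the same table.
memoFix :: ((Tree -> v) -> Tree -> v) -> Tree -> v
memoFix step = go
  where go = memo (step go)

-- Same arithmetic as numSubTrees, but recursing through "self".
countNodesStep :: (Tree -> Int) -> Tree -> Int
countNodesStep _    Leaf       = 1
countNodesStep self (Fork l r) = 2 + self l + self r

countNodes :: Tree -> Int
countNodes = memoFix countNodesStep

-- Balanced tree of depth d, as in the question.
mkTree :: Int -> Tree
mkTree 0 = Leaf
mkTree d = Fork (mkTree (d - 1)) (mkTree (d - 1))
```

In GHCi, countNodes (mkTree 3) evaluates to 22, the same value numSubTrees gives. This only guarantees that the recursive calls consult the table; whether it is any faster depends on how often the same subtree is requested.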

1
votes

Both responders were correct, but here's a more complete response.

There were two errors in my original code. The first, which I corrected in an update, was that my original memoized function was only using the memo table in the first call. The recursive calls were just normal unmemoized function calls.

However, even fixing this error did not lead to speed improvements. This was not because the function failed to consult the memo table, but because each subtree was requested too few times to pay for indexing into the table (a lookup with applyTree walks the entire key tree). If we make the function call itself more often on the same subtrees, memoization does lead to improvements.

-- Memo.hs                                                                                                                                                                                                  

import System.Environment                                                                                                                                                                                   

data Tree = Leaf | Fork Tree Tree deriving Show                                                                                                                                                             
data TTree v = NTree v (TTree (TTree v)) deriving Show                                                                                                                                                      

applyTree :: TTree v -> (Tree -> v)                                                                                                                                                                         
applyTree (NTree tl tf) Leaf = tl                                                                                                                                                                           
applyTree (NTree tl tf) (Fork l r) = applyTree (applyTree tf l) r                                                                                                                                           

tabulateTree :: (Tree -> v) -> TTree v                                                                                                                                                                      
tabulateTree f = NTree (f Leaf) (tabulateTree $ \l                                                                                                                                                          
                                     -> tabulateTree $ \r -> f (Fork l r))                                                                                                                                  

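memofunc :: Tree -> Int
-- Point-free on purpose: memofunc is then a constant, so one memo table
-- is built and shared by all calls.
memofunc = memo func
    where func :: Tree -> Int   -- local helper; shadows the top-level func below
          func (Leaf) = 1
          func (Fork Leaf Leaf) = 1
          -- No case for Fork Leaf (Fork _ _); mkTree never builds such a tree.
          func (Fork l@(Fork a b) r) = memofunc l + memofunc a + memofunc b
                                       + memofunc r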

func :: Tree -> Int                                                                                                                                                                                         
func (Leaf) = 1                                                                                                                                                                                             
func (Fork Leaf Leaf) = 1                                                                                                                                                                                   
func (Fork l@(Fork a b) r) = func l + func a + func b + func r                                                                                                                                              


memo = applyTree . tabulateTree                                                                                                                                                                             

mkTree d | d == 0 = Leaf                                                                                                                                                                                    
         | otherwise = Fork (mkTree $ d-1) (mkTree $ d-1)                                                                                                                                                   

main = do                                                                                                                                                                                                   
  args <- getArgs                                                                                                                                                                                           
  let version = read $ head args                                                                                                                                                                            
      d = read $ args !! 1                                                                                                                                                                                  
      (version_name, out) = if version == 0                                                                                                                                                                 
                            then ("Memoized", (memofunc) (mkTree d))                                                                                                                                        
                            else ("Not memoized", func (mkTree d))                                                                                                                                          
  putStrLn $ version_name ++ ": function apply to tree of depth "                                                                                                                                           
               ++ show d ++ " is: " ++ show out                                                                                                                                                             

This leads to the following memoized and unmemoized run times (on balanced trees of depth 21):

helmholtz:LearningHaskell edechter$ time ./Memo 0 21
Memoized: function apply to tree of depth 21 is: 733219840

real    0m2.954s
user    0m2.781s
sys     0m0.162s
helmholtz:LearningHaskell edechter$ time ./Memo 1 21
Not memoized: function apply to tree of depth 21 is: 733219840

real    0m6.334s
user    0m6.304s
sys     0m0.025s
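As a rough check on why the second function benefits from the table, one can count the work on a balanced tree of depth d. The sketch below is a back-of-the-envelope cost model, not a measurement: plainCalls, treeSize and memoSteps are hypothetical helpers, and memoSteps assumes that the table is shared, so that the body of func runs once per distinct subtree and each lookup costs one applyTree step per node of the key.

```haskell
-- Sketch only: a cost model with illustrative names, not part of Memo.hs.

-- Calls made by the unmemoized func on a balanced tree of depth d.
-- Depth 0 (Leaf) and depth 1 (Fork Leaf Leaf) return immediately;
-- a deeper tree calls itself on l, a, b and r.
plainCalls :: Int -> Integer
plainCalls d = calls !! d
  where
    calls = 1 : 1 : zipWith (\c1 c2 -> 1 + 2 * c1 + 2 * c2) (drop 1 calls) calls

-- Nodes (leaves and forks) in a balanced tree of depth d. applyTree makes
-- one step per node of the key, so this is also the cost of one lookup.
treeSize :: Int -> Integer
treeSize d = 2 ^ (d + 1) - 1

-- Assumed applyTree steps for the memoized version: one lookup of the root,
-- plus four lookups (l, a, b, r) for each distinct subtree of depth >= 2.
memoSteps :: Int -> Integer
memoSteps d =
  treeSize d + sum [2 * treeSize (k - 1) + 2 * treeSize (k - 2) | k <- [2 .. d]]
```

Under these assumptions plainCalls grows faster with d than memoSteps, which grows in proportion to the size of the tree. The two counts measure different operations, so they indicate the trend rather than predict the timings above.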