Memoização em Haskell?

136

Quaisquer dicas sobre como resolver com eficiência a seguinte função no Haskell, para grandes números (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

Eu vi exemplos de memorização em Haskell para resolver números de fibonacci, que envolviam computar (preguiçosamente) todos os números de fibonacci até o n necessário. Mas, neste caso, para um dado n, precisamos apenas calcular muito poucos resultados intermediários.

obrigado

Angel de Vicente
fonte
110
Apenas no sentido de que é um trabalho que eu estou fazendo em casa :-)
Angel de Vicente

Respostas:

256

Podemos fazer isso de maneira muito eficiente, criando uma estrutura que podemos indexar em tempo sublinear.

Mas primeiro,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Vamos definir f, mas faça com que use 'recursão aberta' em vez de se chamar diretamente.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

Você pode obter um não-personalizado fusandofix f

Isso permitirá que você teste fo que você quer dizer com pequenos valores fchamando, por exemplo:fix f 123 = 144

Podemos memorizar isso definindo:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

Isso tem um desempenho razoavelmente bom e substitui o que levaria tempo O (n ^ 3) por algo que memorizasse os resultados intermediários.

Mas ainda leva tempo linear apenas para indexar e encontrar a resposta memorizada mf. Isso significa que resultados como:

*Main Data.List> faster_f 123801
248604

são toleráveis, mas o resultado não é muito melhor que isso. Nós podemos fazer melhor!

Primeiro, vamos definir uma árvore infinita:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

E então definiremos uma maneira de indexá-lo, para que possamos encontrar um nó com índice nno tempo O (log n) :

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... e podemos achar uma árvore cheia de números naturais conveniente, para que não tenhamos que mexer com esses índices:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Como podemos indexar, você pode converter uma árvore em uma lista:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

Você pode verificar o trabalho até agora, verificando se isso toList natsoferece[0..]

Agora,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

funciona exatamente como na lista acima, mas, em vez de levar um tempo linear para encontrar cada nó, pode persegui-lo em tempo logarítmico.

O resultado é consideravelmente mais rápido:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

Na verdade, é muito mais rápido que você possa passar e substituir Intcom Integeracima e obter ridiculamente grandes respostas quase que instantaneamente

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
Edward KMETT
fonte
3
Eu tentei esse código e, curiosamente, f_faster parecia ser mais lento que f. Eu acho que essas referências realmente atrasaram as coisas. A definição de nats e índice parecia bastante misteriosa para mim, então adicionei minha própria resposta que pode tornar as coisas mais claras.
Pitarou
5
O caso de lista infinito deve lidar com uma lista vinculada de 111111111 itens. O caso em árvore está lidando com log n * o número de nós atingidos.
Edward KMETT
2
ou seja, a versão da lista precisa criar thunks para todos os nós da lista, enquanto a versão em árvore evita a criação de muitos deles.
Tom Ellis
7
Eu sei que esta é uma postagem bastante antiga, mas não deve f_treeser definida em uma wherecláusula para evitar salvar caminhos desnecessários na árvore nas chamadas?
Dfeuer
17
O motivo para colocá-lo em um CAF é que você pode obter memorização nas chamadas. Se eu tivesse uma ligação cara, estava memorizando, provavelmente a deixaria em um CAF, daí a técnica mostrada aqui. Em uma aplicação real, há uma troca entre os benefícios e os custos da memorização permanente, é claro. Embora, dada a pergunta sobre como obter a memorização, acho que seria enganoso responder com uma técnica que evita deliberadamente a memorização em chamadas, e se nada mais, esse comentário aqui irá apontar as pessoas para o fato de que há sutilezas. ;)
Edward KMETT
17

A resposta de Edward é uma jóia tão maravilhosa que eu a dupliquei e forneci implementações memoListe memoTreecombinadores que memorizam uma função de forma aberta-recursiva.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
Tom Ellis
fonte
12

Não é a maneira mais eficiente, mas memoriza:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

ao solicitar f !! 144, é verificado se f !! 143existe, mas seu valor exato não é calculado. Ainda está definido como resultado desconhecido de um cálculo. Os únicos valores exatos calculados são os necessários.

Então, inicialmente, na medida em que foi calculado, o programa não sabe nada.

f = .... 

Quando fazemos a solicitação f !! 12, ela começa a fazer alguma correspondência de padrões:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Agora começa a calcular

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

Isso recursivamente faz outra demanda em f, então calculamos

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Agora podemos voltar alguns

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

O que significa que o programa agora sabe:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Continuando a chegar:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

O que significa que o programa agora sabe:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Agora continuamos com nosso cálculo de f!!6:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

O que significa que o programa agora sabe:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Agora continuamos com nosso cálculo de f!!12:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

O que significa que o programa agora sabe:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

Portanto, o cálculo é feito com preguiça. O programa sabe que f !! 8existe algum valor para , é igual a g 8, mas não tem idéia do que g 8é.

rampion
fonte
Obrigado por este. Como você criaria e usaria um espaço de solução bidimensional? Isso seria uma lista de listas? eg n m = (something with) f!!a!!b
vikingsteve
1
Claro, você poderia. Para uma solução real, porém, eu provavelmente usar uma biblioteca memoization, como memocombinators
rampion
É O (n ^ 2) infelizmente.
precisa saber é o seguinte
8

Este é um adendo à excelente resposta de Edward Kmett.

Quando tentei o código dele, as definições natse indexpareciam bastante misteriosas, então escrevi uma versão alternativa que achei mais fácil de entender.

Eu defino indexe natsem termos de index'e nats'.

index' t né definido sobre o intervalo [1..]. (Lembre-se de que index té definido acima do intervalo [0..].) Ele funciona pesquisando a árvore tratando ncomo uma sequência de bits e lendo os bits ao contrário. Se o bit for 1, é necessário o ramo direito. Se o bit estiver 0, ele pega o ramo esquerdo. Para quando atinge o último bit (que deve ser a 1).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Assim como natsé definido para, indexpara que index nats n == nsempre seja verdadeiro, nats'é definido para index'.

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

Agora, natse indexsão simples nats'e index'com os valores alterados por 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
Pitarou
fonte
Obrigado. Estou memorizando uma função multivariada, e isso realmente me ajudou a descobrir o que índice e nats estavam realmente fazendo.
Kittsil
8

Conforme declarado na resposta de Edward Kmett, para acelerar as coisas, você precisa armazenar em cache cálculos caros e poder acessá-los rapidamente.

Para manter a função não monádica, a solução de construir uma árvore lenta e infinita, com uma maneira apropriada de indexá-la (como mostrado nas postagens anteriores), cumpre esse objetivo. Se você abandonar a natureza não monádica da função, poderá usar os contêineres associativos padrão disponíveis no Haskell em combinação com as mônadas "semelhantes a estados" (como Estado ou ST).

Embora a principal desvantagem seja a obtenção de uma função não monádica, você não precisa mais indexar a estrutura e pode usar implementações padrão de contêineres associativos.

Para fazer isso, primeiro você precisa reescrever sua função para aceitar qualquer tipo de mônada:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

Para seus testes, você ainda pode definir uma função que não memoriza usando Data.Function.fix, embora seja um pouco mais detalhada:

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

Em seguida, você pode usar a mônada do estado em combinação com o Data.Map para acelerar as coisas:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

Com pequenas alterações, você pode adaptar o código para trabalhar com Data.HashMap:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

Em vez de estruturas de dados persistentes, você também pode tentar estruturas de dados mutáveis ​​(como o Data.HashTable) em combinação com a mônada ST:

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

Comparado à implementação sem nenhuma memorização, qualquer uma dessas implementações permite que, para grandes entradas, obtenha resultados em microssegundos, em vez de esperar vários segundos.

Usando o Critério como referência, pude observar que a implementação com o Data.HashMap realmente teve um desempenho um pouco melhor (cerca de 20%) do que os Data.Map e Data.HashTable para os quais os tempos eram muito semelhantes.

Achei os resultados do benchmark um pouco surpreendentes. Meu sentimento inicial era que o HashTable superaria a implementação do HashMap porque é mutável. Pode haver algum defeito de desempenho oculto nesta última implementação.

Quentin
fonte
2
O GHC faz um ótimo trabalho de otimização em torno de estruturas imutáveis. A intuição de C nem sempre dá certo.
John Tyree
3

Alguns anos depois, observei isso e percebi que havia uma maneira simples de memorizar isso em tempo linear usando zipWithuma função auxiliar:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilate tem a propriedade útil que dilate n xs !! i == xs !! div i n .

Então, supondo que recebamos f (0), isso simplifica a computação para

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

Parecendo muito com a descrição original do problema e fornecendo uma solução linear ( sum $ take n fsserá usado O (n)).

rampion
fonte
2
portanto, é uma solução generativa (corecursiva?) ou de programação dinâmica. Tomando O (1) o tempo para cada valor gerado, como o Fibonacci usual está fazendo. Ótimo! E a solução da EKMETT é como os grandes Fibonacci logarítmicos, alcançando grandes números muito mais rapidamente, pulando muitos dos entres. Isso é certo?
Will Ness
ou talvez seja mais próximo do número de Hamming, com os três indicadores de retorno da sequência que está sendo produzida e as diferentes velocidades de cada um deles avançando ao longo dela. muito bonito.
Will Ness
2

Mais um adendo à resposta de Edward Kmett: um exemplo independente:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Use-o da seguinte maneira para memorizar uma função com um único número inteiro arg (por exemplo, fibonacci):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Somente valores para argumentos não negativos serão armazenados em cache.

Para também armazenar em cache valores para argumentos negativos, use memoInt, definido da seguinte maneira:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

Para armazenar em cache valores para funções com dois argumentos inteiros memoIntInt, use o seguinte:

memoIntInt f = memoInt (\n -> memoInt (f n))
Neal Young
fonte
2

Uma solução sem indexação e não baseada em Edward KMETT.

Eu fatoro subárvores comuns a um pai comum ( f(n/4)é compartilhado entre f(n/2)e f(n/4), e f(n/6)é compartilhado entre f(2)e f(3)). Salvando-os como uma única variável no pai, o cálculo da subárvore é feito uma vez.

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

O código não se estende facilmente a uma função de memorização geral (pelo menos, eu não saberia como fazê-lo), e você realmente precisa pensar em como os subproblemas se sobrepõem, mas a estratégia deve funcionar para vários parâmetros não inteiros gerais . (Pensei em dois parâmetros de string.)

A nota é descartada após cada cálculo. (Mais uma vez, eu estava pensando em dois parâmetros de string.)

Não sei se isso é mais eficiente que as outras respostas. Cada pesquisa é tecnicamente apenas uma ou duas etapas ("Olhe para o seu filho ou filho do seu filho"), mas pode haver muito uso de memória extra.

Edit: Esta solução ainda não está correta. O compartilhamento está incompleto.

Edit: Ele deve compartilhar os sub-filhos corretamente agora, mas percebi que esse problema tem muitos compartilhamentos não triviais: n/2/2/2e n/3/3pode ser o mesmo. O problema não é um bom ajuste para minha estratégia.

leewz
fonte