関数の自動メモ化

groovy 1.8 では、memoize によってクロージャのメモ化が出来るようになったけれど、scala だってできるもん、という負け惜しみエントリ。

普通の自動メモ化

サクっと作ったものを(1〜5引数対応)をGistに上げたので簡単に紹介。

実装のポイントは単純で、下記の通り。

  // 1変数関数をメモ化する(2変数以上は tupled/untupledで対応)
  def memoize[ArgT, RetT](f: ArgT => RetT): ArgT => RetT = {
    val memo = scala.collection.mutable.Map.empty[ArgT, RetT]
    arg => memo.getOrElseUpdate(arg, f(arg))
  }

こんな風にして使う。

object Main extends Application {
  val fib: Int => Long = memoize {
    case 0 => 0
    case 1 => 1
    case n => fib(n - 1) + fib(n - 2)
  }
  println(fib(40))
}

再帰関数をメモ化する場合には、要注意な点がある。もとの関数で再帰呼び出しをする際には、メモ化された後の関数を呼び出さないとダメな点。そうしないと、再帰の呼出しがメモ化されない。上記コード例のように、あくまで、メモ化された後の関数実体が束縛される変数名をメモ化される前の関数から参照しなくてはならない。

下記の例は、ダメな例。

  // これはダメ
  val fib: Int => Long = {
    case 0 => 0
    case 1 => 1
    case n => fib(n - 1) + fib(n - 2)
  }
  val fibMemo = memoized(fib)

そう、まさに Groovy の memoize と同じ問題が起きる。

2011-05-08 追記

Groovy での再帰対応については、id:dev68 さんのエントリが参考になります。
エントリ全体としては、Groovy の Meta Object Protocol (MOP) の紹介と 1.8 での変更点に関するもので、非常に参考になりました。

再帰対応などの方法を検討(失敗談など)

普通は、この先って、いつキャシュを破棄するのかとかそういうところを詰めていくんじゃないのかなー、(例えば、外からキャッシュ解放操作を受け付けるとか、MRUキャッシュに置き換えるとか、WeakReferenceでドツボにはまるとか etc.)と思いつつ、あえて、「厳密に言えば再帰に対応していない」点を突き詰めてみる。前述のように、メモ化後のシンボルを参照すれば問題ないとは言え、おいらが許さない。

そこで、厳密に再帰に対応、かつ、動的に自動メモするにはどうすれば良いのか、検討してみた。なお(1)動的とは、関数を値として受け取るという意味で、(2)厳密に再帰対応とは、再帰を含んでいてもメモ化後に束縛されるシンボルを参照せずにメモ化できるという意味。

(1) だけならば前述の方法で、(2)だけならば後述の trait との mixin で対応可能だし、そもそも関数の形を変えていいのならば、不動点コンビネータを用いるのが定石。だけれど、関数の形を変えずに(1)と(2)を融合するのは、それなりの黒魔術が必要になる。

この先は失敗談がてら、簡単に何を検討したのかメモしておく。

不動点コンビネータ

まず、λ再帰(自分の名前を参照せずに再帰する方法)では、不動点コンビネータを使うのが定石。しかし、当然関数の形は変わってしまうので、実行中に関数を受け取ってメモ化するという点では、今回の目的にはそぐわない。

そぐわないとは分かっているけれど、あえて泥沼に足を踏み入れてみる。エイヤー!

まず 不動点コンビネータの基本である Y combinator を scala で書くと、下記のような感じ。(既に scala 実装はあるのだけれど、そこはあえて我流で)

  type F[A, B] = ((A => B), A) => B
  type Y[A, B] = (F[A, B], A) => B
  def Y[A, B](f: F[A, B], x: A): B = f((x1: A) => Y(f, x1), x)

書いてみて分かる 型つきのよさ。型がどう変化するかを type で記述できるという scala の良さを見なおした瞬間でもある。まあ、私は計算機科学に関する高等教育を受けていないので Y combinator に馴染みが薄く、先に型を書かないと飲み込めないというだけのことかもしれないけれどね。

以下、ねっとり(not サクっと)実装してみる。

object TestMemoRecFP {
  // memoization using Y-combinator
  type F[A, B] = ((A => B), A) => B
  type Y[A, B] = (F[A, B], A) => B
  def Y[A, B](f: F[A, B], x: A): B = f((x1: A) => Y(f, x1), x)
  def memoize[A, B](f: F[A, B]): F[A, B] = {
    val cache = collection.mutable.Map.empty[A, B]
    (fx, arg) => cache.getOrElseUpdate(arg, f(fx, arg))
  }
  
  // Factorial 
  val factorial: F[Int, Long] = (f, arg) => arg match {
    case 0 => 1
    case x if x > 0 => x.toLong * f(x - 1)
    case x if x < 0 => throw new IllegalArgumentException(
        "requirement failed: argument of factorial must be > 0")
  }
  val factoricalWithMemo = memoize(factorial)
  
  // Fibonacci sequence extended to negative index
  // What is the definition of fibonacci number with negative indices?
  // @see http://en.wikipedia.org/wiki/Fibonacci_number
  val fibonacci: F[Int, BigInt] = (f, arg) => arg match {
    case 0 => 0
    case 1 => 1
    case x if (x > 0) => f(x - 1) + f(x - 2)
    case x if (x % 2 == 0) => -f(-x)  // x < 0 && x is even
    case x => f(-x)                   // x < 0 && x is odd
  }
  val fibonacciWithMemo = memoize(fibonacci)
  
  def exampleOfMemoizationWithYcomb = {
    println(Y(fibonacciWithMemo, 10))
    println(Y(fibonacci, 10))
    
    println(Y(fibonacciWithMemo, 34))
    println(Y(fibonacci, 34))
    
    println(Y(fibonacciWithMemo, 100))
    //println(Y(fibonacci, 100))  // it takes too long as if it never ends
  }
  
  def main(args : Array[String]) {
    exampleOfMemoizationWithYcomb
  }
}
メモ化する trait と mixin

コップ本の Queue をデコレートする mixin の解説部分(Doubling とか、そんな trait をつくるところ)を見ながら、必死で作ってみた。(具体的な参照箇所は、原著第2版 12.5 The Doubling stackable modification trait を見たんだけれど、日本語のコップ本も、無料の原著第1版 12.5も同内容)

やる前から分かっていることだが、再帰には対応するんだけれど、関数を値として受けとってメモ化することはできない。(関数のクラスを受け取って mixin するか、匿名クラスをその場で定義しなくてはならない)。つまり、匿名関数リテラルの値をメモ化したりできないので、本エントリの最初に掲げたやりかたよりも、実用上すごく不便。

trait Memoized1[T1, R] extends Function1[T1, R] {
  val cache = collection.mutable.Map.empty[T1, R]
  abstract override def apply (v1: T1): R =
    cache.getOrElseUpdate(v1, super.apply(v1))
}

object TestMemoWithTrait {
  def testWithTraits {

    class FibonacciClass extends Function1[Int, BigInt] {
      override def apply(n: Int) = {
        println("Function called: "+this.getClass.toString+"; arg n = "+n)
        n match {
          case x if (x < 0) =>
            throw new IllegalArgumentException("The given argument must be >= 0")
          case 0 => 0
          case 1 => 1
          case x if (x >= 2) => apply(x - 1) + apply(x - 2)
        }        
      }
    }
    object fib extends FibonacciClass
    object fibMemo extends FibonacciClass with Memoized1[Int, BigInt]
    
    val f = new Memoized1[Int, Int] {
      def apply(n: Int): Int = {
        println("Function called: "+this.getClass.toString+"; arg n = "+n)
        n match {
          case x if x < 0 => 
            throw new IllegalArgumentException("The given argument must be >= 0")
          case 0 => 0
          case 1 => 1
          case x if (x >= 2) => apply(x - 1) + apply(x - 2)
        }
      }
    }
    
    println(fib(5))
    println(fibMemo(5))
    println(f(5))
    
  }
  def main(args : Array[String]) {
    testWithTraits
  }
}
関数リテラルを受け取りつつ、その匿名クラスを Manifest でゴニョ

先程の mixin の方法は、再帰に対応する正当な方法(他の方法としては、コンパイラプラグインか Byte code engineering みたいな黒魔術になるはず)なんだけれど、mixin はあくまでクラスに適用するものだから、クラスがわからないことにはどうにもならない。

そこで、関数の値を受け取りつつそいつ自体は完全に無視して、implicit なパラメータでその関数値のクラス(通常は関数リテラルを書いたときに自動生成される匿名クラス)を受け取って、そいつと mixin してあげようと。そんな無茶な考えが下記。ちなみに scala 2.9.0 RC1 です。

import scala.reflect.Manifest
import scala.tools.nsc.interpreter.IMain
import scala.tools.nsc.Settings
import scala.tools.nsc.settings._
import scala.tools.nsc.util.BatchSourceFile
import scala.tools.util.PathResolver

object DynamicMemoizer 
  extends java.lang.ClassLoader(getClass.getClassLoader) {
  private val id = Iterator.from(1)
  def createUniqueId = synchronized { "DynamicMemoizerKlass" + id.next }
  
  def apply[A, R](func: A => R)(implicit a: Manifest[A], r: Manifest[R]) = {
    val id = createUniqueId
    val classDef = "class %s extends %s with Memoizing[%s, %s]".
      format(id, func.getClass.getName, a.toString, r.toString)
      
    println(classDef)
    
    val settings = new Settings(println)
    settings.usejavacp.value = true
    val interpreter = new IMain(settings)
    interpreter.setContextClassLoader

    interpreter.compileSources(new BatchSourceFile("<anon>", classDef))

    val bytes = interpreter.classLoader.findBytesForClassName(id)

    val clazz = defineClass(id, bytes, 0, bytes.length).asInstanceOf[Class[(A => R) with Memoizing[A, R]]]
    
    clazz.newInstance
  }
}

trait Memoizing[T, R] extends Function1[T, R] {
  val memo = scala.collection.mutable.Map.empty[T, R]
  abstract override def apply(arg: T): R =
    memo.getOrElseUpdate(arg, super.apply(arg))
}

object DynamicMemoMain extends Application {
  
  // 比較用
  class fibClass extends (Int => Long) {
    def apply(n: Int): Long = n match {
      case 0 => 0
      case 1 => 1
      case _ => apply(n - 2) + apply(n - 1)
    }    
  }
  object fibStaticMixin extends fibClass with Memoizing[Int, Long]  
  
  val fib: Int => Long = {
    case 0 => 0
    case 1 => 1
    case n => fib(n - 1) + fib(n - 2)
  }
  // 全てはこれのために、動的自動メモ化発動!
  val fibDynamicMixin = DynamicMemoizer(fib)
  
  println(fibStaticMixin(40))
  println(fibDynamicMixin(40))
}

なんじゃこりゃーレベルOTL

AOP (メモ化とは、アスペクト横断な関心事なんですキリッ!)

なんじゃこりゃーついでに、spring aop とか使ってみる。
なんというか、メモ化って、関数呼び出しごとにキャッシュ参照するっていうアスペクト横断的関心事なんですよ、という発想。もちろん、厳密な意味では再帰呼び出しに対応しない。(最初に示した方法と同様、メモ化後のシンボルを参照する必要がある)

import org.aopalliance.intercept.{MethodInterceptor, MethodInvocation}
import org.springframework.aop.framework.ProxyFactory

// Memoization with `MethodInterceptor`
class Memoizer extends MethodInterceptor {
  import scala.collection.mutable.WrappedArray
  private val cacheMap = collection.mutable.Map.empty[WrappedArray[AnyRef], AnyRef]
  def invoke(invocation: MethodInvocation): AnyRef =
    cacheMap.getOrElseUpdate(invocation.getArguments, invocation.proceed)
}

object TestMemo {
  def memoize[T](funcObj: T): T = {
    val pf = new ProxyFactory
    pf.setTarget(funcObj)
    pf.addAdvice(new Memoizer)
    pf.getProxy.asInstanceOf[T]
  }
  
  def exampleOfMemoizationWithAOP = {
    def sqr(x: Int): Int = {
      println("method invoked!")
      x * x
    }
    val sqrMemo = memoize(sqr _)
    
    println("sqr first call:")
    println(1 to 3 map sqr)
    println("sqr second call:")
    println(1 to 3 map sqr)
    
    println("sqrMemo first call")
    println(1 to 3 map sqrMemo)
    println("sqrMemo with memo second call")
    println(1 to 3 map sqrMemo)
  }
  
  def main(args : Array[String]) {
    exampleOfMemoizationWithAOP
  }
}
おわりに

まあ、せっかくなので、失敗談も含めて晒してみました。自分の無能さ加減(単に日本語が下手くそというだけでなく、実装力も低い)の記録ですね。後日に振り返って、「あー、おいらも成長したなー」を実感するための材料に化ける?

あと未検討な手法は、BCE(Byte code engineering) と コンパイラプラグインがあるのだけれど、BCEの場合は相互再帰とかもっと複雑な再帰をどうするかという問題があるのでとても大変。とすると、コンパイラプラグインが有望なのかな。