Functional Programming (29) - Practical Functional Structures: Trampoline - No More Fear of StackOverflow

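One hallmark of functional programming is its heavy use of recursion, and in some places recursion simply cannot be avoided. flatMap, for example, is a recursive algorithm that drives a computation forward step by step; without it there is no for-comprehension, and functional programming could hardly be called Monadic Programming. Recursion makes code concise and easy to follow, but it runs by consuming the call stack. The stack is a limited resource, so a recursive algorithm applied to a large data source will often fail with a StackOverflowError. Unless we find a way around the stack overflows that recursion brings, functional programming loses much of its practical value.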

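To address StackOverflow, the Scala compiler can optimize one particular recursive pattern by converting the recursion into a while loop, but only for tail recursion (TCE, Tail Call Elimination). Let's look at TCE through an example.

Here is a right fold: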

def foldR[A,B](as: List[A], b: B, f: (A,B) => B): B = as match {
    case Nil => b
    case h :: t => f(h,foldR(t,b,f))
}                                                 //> foldR: [A, B](as: List[A], b: B, f: (A, B) => B)B
def add(a: Int, b: Int) = a + b                   //> add: (a: Int, b: Int)Int

foldR((1 to 100).toList, 0, add)                  //> res0: Int = 5050
foldR((1 to 10000).toList, 0, add)                //> java.lang.StackOverflowError

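In this right fold the self-reference is not in tail position (its result is still passed to f), so the Scala compiler cannot apply TCE, and processing a 10000-element List overflows the stack.

Now look at the left fold: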

def foldL[A,B](as: List[A], b: B, f: (B,A) => B): B = as match {
    case Nil => b
    case h :: t => foldL(t,f(b,h),f)
}                                                 //> foldL: [A, B](as: List[A], b: B, f: (B, A) => B)B
foldL((1 to 100000).toList, 0, add)               //> res1: Int = 705082704

In this left fold the self-reference foldL is in tail position, so the Scala compiler can apply TCE and turn it into a while loop:

def foldl2[A,B](as: List[A], b: B,
                f: (B,A) => B): B = {
    var z = b
    var az = as
    while (true) {
        az match {
            case Nil => return z
            case x :: xs => {
                z = f(z, x)
                az = xs
            }
        }
    }
    z
}

After the conversion the recursive call becomes a jump and the program no longer consumes stack, so no StackOverflow occurs.
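One way to confirm that the compiler really performs this conversion is the standard `scala.annotation.tailrec` annotation, which turns a failed optimization into a compile error. The following is a minimal sketch, not the worksheet code: the wrapper object `TailRecCheck`, the helper `sumTo`, and the curried signature are assumptions made for illustration.

```scala
import scala.annotation.tailrec

object TailRecCheck {
  // @tailrec makes compilation fail if the recursive call is not in tail position.
  @tailrec
  def foldL[A, B](as: List[A], b: B)(f: (B, A) => B): B = as match {
    case Nil    => b
    case h :: t => foldL(t, f(b, h))(f)
  }

  // A Long accumulator avoids the Int overflow visible in res1 above.
  def sumTo(n: Int): Long = foldL((1 to n).toList, 0L)(_ + _)
}
```

Annotating foldR the same way would be rejected by the compiler, because its recursive call is wrapped in f(...).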

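In practice, though, it is unrealistic to write every recursive algorithm as tail recursion. Some more complex algorithms cannot be expressed in tail-recursive form at all, and TCE on the JVM is limited: only local tail recursion (a function calling itself) can be optimized.

Let's look at a slightly more complex example: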

 

def even[A](as: List[A]): Boolean = as match {
    case Nil => true
    case h :: t => odd(t)
}                                                 //> even: [A](as: List[A])Boolean
def odd[A](as: List[A]): Boolean = as match {
    case Nil => false
    case h :: t => even(t)
}                                                 //> odd: [A](as: List[A])Boolean

 

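Here even and odd each make a tail call, but to the other function. The Scala compiler cannot apply TCE in this case because the JVM does not support jumps between functions: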

 

even((1 to 100).toList)                           //> res2: Boolean = true
even((1 to 101).toList)                           //> res3: Boolean = false
odd((1 to 100).toList)                            //> res4: Boolean = false
odd((1 to 101).toList)                            //> res5: Boolean = true
even((1 to 10000).toList)                         //> java.lang.StackOverflowError

 

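Processing a 10000-element List still ends in a StackOverflowError.

We can design a data structure that trades stack for heap. Trampoline is a data structure designed specifically to solve the StackOverflow problem: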

 

trait Trampoline[+A] {
    final def runT: A = this match {
        case Done(a) => a
        case More(k) => k().runT
    }
}
case class Done[+A](a: A) extends Trampoline[A]
case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]

 

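A Trampoline represents a computation that can be carried out one step at a time. Each step has two possibilities: Done(a), which completes the computation and returns the result a, or More(k), which runs k and moves on to the next step; that next step may again be either Done or More. Note that Trampoline's runT method is plainly tail recursive, and runT is marked final, which allows Scala to apply TCE.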
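To make the "one step at a time" idea concrete, here is a rough sketch of the loop that a tail-call-eliminated runT amounts to. It is an illustration only: the object `TrampolineLoop`, the free-standing `run` function, and the `count` helper are hypothetical names, not part of the worksheet.

```scala
object TrampolineLoop {
  sealed trait Trampoline[+A]
  final case class Done[+A](a: A) extends Trampoline[A]
  final case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]

  // Each More is a heap-allocated thunk; the loop calls it and returns to the
  // same stack frame, so the stack depth stays constant.
  def run[A](t: Trampoline[A]): A = {
    var cur: Trampoline[A] = t
    var result: Option[A] = None
    while (result.isEmpty) {
      cur match {
        case Done(a) => result = Some(a)
        case More(k) => cur = k()
      }
    }
    result.get
  }

  // Counts up to n in n suspended steps.
  def count(n: Int, acc: Int = 0): Trampoline[Int] =
    if (n == 0) Done(acc) else More(() => count(n - 1, acc + 1))
}
```

Calling `TrampolineLoop.run(TrampolineLoop.count(1000000))` takes a million steps, but every step starts from the same frame of `run`.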

With Trampoline we can change the result type of even and odd to Trampoline:

def even[A](as: List[A]): Trampoline[Boolean] = as match {
    case Nil => Done(true)
    case h :: t => More(() => odd(t))
}                                                 //> even: [A](as: List[A])ch13.ex1.Trampoline[Boolean]
def odd[A](as: List[A]): Trampoline[Boolean] = as match {
    case Nil => Done(false)
    case h :: t => More(() => even(t))
}                                                 //> odd: [A](as: List[A])ch13.ex1.Trampoline[Boolean]

We use Trampoline's runT to compute the result:

 

even((1 to 10000).toList).runT                    //> res6: Boolean = true
even((1 to 10001).toList).runT                    //> res7: Boolean = false
odd((1 to 10000).toList).runT                     //> res8: Boolean = false
odd((1 to 10001).toList).runT                     //> res9: Boolean = true

 

This time we get the correct results and no StackOverflow error. Is it really that simple?
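For comparison, the Scala standard library ships a similar facility in `scala.util.control.TailCalls`. The sketch below swaps it in for the hand-written Trampoline (`done` plays the role of Done, `tailcall` the role of More, and `result` the role of runT); the wrapper object `StdlibEvenOdd` is a hypothetical name used only here.

```scala
import scala.util.control.TailCalls.{TailRec, done, tailcall}

object StdlibEvenOdd {
  // Same mutual recursion as above, expressed with the standard library's TailRec.
  def even[A](as: List[A]): TailRec[Boolean] = as match {
    case Nil    => done(true)
    case _ :: t => tailcall(odd(t))
  }

  def odd[A](as: List[A]): TailRec[Boolean] = as match {
    case Nil    => done(false)
    case _ :: t => tailcall(even(t))
  }
}
```

`StdlibEvenOdd.even((1 to 100000).toList).result` evaluates the whole chain without growing the stack.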

Let's analyze a more realistic and slightly more complex example. Here we traverse a List while maintaining a state. First we need a State type:

 

case class State[S,+A](runS: S => (A,S)) {
    import State._
    def flatMap[B](f: A => State[S,B]): State[S,B] = State[S,B] {
        s => {
            val (a1,s1) = runS(s)
            f(a1) runS s1
        }
    }
    def map[B](f: A => B): State[S,B] = flatMap( a => unit(f(a)))
}
object State {
    def unit[S,A](a: A) = State[S,A] { s => (a,s) }
    def getState[S]: State[S,S] = State[S,S] { s => (s,s) }
    def setState[S](s: S): State[S,Unit] = State[S,Unit] { _ => ((),s)}
}

 

Then we use State to write a function that labels each List element with its index:

import State._   // brings unit, getState and setState into scope
def zip[A](as: List[A]): List[(A,Int)] = {
    as.foldLeft(
        unit[Int,List[(A,Int)]](List()))(
        (acc,a) => for {
            xs <- acc
            n <- getState[Int]
            _ <- setState[Int](n + 1)
        } yield (a,n) :: xs
    ).runS(0)._1.reverse
}                                                 //> zip: [A](as: List[A])List[(A, Int)]

Run this zip function:

 

zip((1 to 10).toList)                             //> res0: List[(Int, Int)] = List((1,0), (2,1), (3,2), (4,3), (5,4), (6,5), (7,6
                                                  //| ), (8,7), (9,8), (10,9))

 

The result is correct. What about a large List?

 

zip((1 to 10000).toList)                          //> java.lang.StackOverflowError

 

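foldLeft is tail recursive, so why the StackOverflow? Because State's flatMap is itself a recursive algorithm: foldLeft only builds a chain of nested flatMap calls, and running that chain with runS is what overflows the stack. How can we improve on this? Could we, as before, simply change the result type of the state transition to Trampoline?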

 

case class State[S,A](runS: S => Trampoline[(A,S)]) {
    import State._
    def flatMap[B](f: A => State[S,B]): State[S,B] = State[S,B] {
        s => More(() => {
            val (a1,s1) = runS(s).runT
            More(() => f(a1) runS s1)
        })
    }
    def map[B](f: A => B): State[S,B] = flatMap( a => unit(f(a)))
}
object State {
    def unit[S,A](a: A) = State[S,A] { s => Done((a,s)) }
    def getState[S]: State[S,S] = State[S,S] { s => Done((s,s)) }
    def setState[S](s: S): State[S,Unit] = State[S,Unit] { _ => Done(((),s))}
}
trait Trampoline[+A] {
    final def runT: A = this match {
        case Done(a) => a
        case More(k) => k().runT
    }
}
case class Done[+A](a: A) extends Trampoline[A]
case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]

import State._   // brings unit, getState and setState into scope
def zip[A](as: List[A]): List[(A,Int)] = {
    as.foldLeft(
        unit[Int,List[(A,Int)]](List()))(
        (acc,a) => for {
            xs <- acc
            n <- getState[Int]
            _ <- setState[Int](n + 1)
        } yield (a,n) :: xs
    ).runS(0).runT._1.reverse
}                                                 //> zip: [A](as: List[A])List[(A, Int)]
zip((1 to 10).toList)                             //> res0: List[(Int, Int)] = List((1,0), (2,1), (3,2), (4,3), (5,4), (6,5), (7,
                                                  //| 6), (8,7), (9,8), (10,9))

 

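In this example we changed the state transition function from S => (A,S) to S => Trampoline[(A,S)] and adjusted the types of the related functions accordingly. Running zip gives the correct result. Now try a large List again: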

zip((1 to 10000).toList)                          //> java.lang.StackOverflowError

StackOverflow again. This time it is because the runT call inside flatMap is not in tail position. What if we turn Trampoline into a Monad? For that we need to add a flatMap function to Trampoline:

 

trait Trampoline[+A] {
    final def runT: A = this match {
        case Done(a) => a
        case More(k) => k().runT
    }
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = {
        this match {
            case Done(a) => f(a)
            case More(k) => f(runT)
        }
    }
}
case class Done[+A](a: A) extends Trampoline[A]
case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]

Now State.flatMap can be adjusted as follows:

case class State[S,A](runS: S => Trampoline[(A,S)]) {
    import State._
    def flatMap[B](f: A => State[S,B]): State[S,B] = State[S,B] {
        s => More(() => {
//          val (a1,s1) = runS(s).runT
//          More(() => f(a1) runS s1)
            runS(s) flatMap {   // runS(s) >>> Trampoline
                case (a1,s1) => More(() => f(a1) runS s1)
            }
        })
    }
    def map[B](f: A => B): State[S,B] = flatMap( a => unit(f(a)))
}

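We have now pushed all the recursion into Trampoline.flatMap. However, the runT call in f(runT) inside Trampoline.flatMap is not in tail position, so this adjustment is still not enough. The core issue remains making flatMap stack-safe. We can add one more state, FlatMap, to Trampoline and turn the flatMap function call into the construction of a type instance (type construction):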

case class Done[+A](a: A) extends Trampoline[A]
case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]
case class FlatMap[A,B](sub: Trampoline[A], k: A => Trampoline[B]) extends Trampoline[B]

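The FlatMap case class is a Trampoline state that means: run sub first, pass its result to the next step k, then run k. It essentially records what flatMap does as data. Next we give Trampoline a resume method and adjust Trampoline.flatMap so that both take the FlatMap state into account: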

trait Trampoline[+A] {
    final def runT: A = resume match {
        case Right(a) => a
        case Left(k) => k().runT
    }
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = {
        this match {
//          case Done(a) => f(a)
//          case More(k) => f(runT)
            case FlatMap(a,g) => FlatMap(a, (x: Any) => g(x) flatMap f)
            case x => FlatMap(x, f)
        }
    }
    def map[B](f: A => B) = flatMap(a => Done(f(a)))
    def resume: Either[() => Trampoline[A], A] = this match {
        case Done(a) => Right(a)
        case More(k) => Left(k)
        case FlatMap(a,f) => a match {
            case Done(v) => f(v).resume
            case More(k) => Left(() => k() flatMap f)
            case FlatMap(b,g) => FlatMap(b, (x: Any) => g(x) flatMap f).resume
        }
    }
}
case class Done[+A](a: A) extends Trampoline[A]
case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]
case class FlatMap[A,B](sub: Trampoline[A], k: A => Trampoline[B]) extends Trampoline[B]

In these adjustments to Trampoline we rely on the associativity law of Monads:

FlatMap(FlatMap(b,g),f) == FlatMap(b,x => FlatMap(g(x),f))

Once re-associated to the right, FlatMap can correctly express multi-step computations.
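The effect of this re-association can be sketched with a small interpreter. This is a simplified illustration under stated assumptions, not the resume-based implementation above: `Trampoline` is made invariant here to keep the typing simple, `flatMap` only builds a FlatMap node, and the object `ReassocSketch`, the free-standing `run`, and the `leftNested` helper are hypothetical names.

```scala
import scala.annotation.tailrec

object ReassocSketch {
  // Invariant in A (unlike Trampoline[+A] in the text), an assumption of this sketch.
  sealed trait Trampoline[A] {
    // flatMap only records the step on the heap; nothing runs here.
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = FlatMap(this, f)
  }
  final case class Done[A](a: A) extends Trampoline[A]
  final case class More[A](k: () => Trampoline[A]) extends Trampoline[A]
  final case class FlatMap[A, B](sub: Trampoline[A], k: A => Trampoline[B]) extends Trampoline[B]

  @tailrec
  def run[A](t: Trampoline[A]): A = t match {
    case Done(a) => a
    case More(k) => run(k())
    case FlatMap(sub, f) => sub match {
      case Done(a) => run(f(a))
      case More(k) => run(k() flatMap f)
      // FlatMap(FlatMap(b, g), f) is rewritten to FlatMap(b, x => FlatMap(g(x), f))
      case FlatMap(b, g) => run(b flatMap (x => g(x) flatMap f))
    }
  }

  // Builds ((Done(0) flatMap inc) flatMap inc) ... nested to the left n times.
  def leftNested(n: Int): Trampoline[Int] =
    (1 to n).foldLeft(Done(0): Trampoline[Int])((t, _) => t.flatMap[Int](x => Done(x + 1)))
}
```

Because every branch of `run` ends in a self-call, `@tailrec` accepts it, and a chain such as `ReassocSketch.leftNested(100000)` is peeled one FlatMap at a time instead of one stack frame at a time.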

Now try running zip again:

import State._   // brings unit, getState and setState into scope
def zip[A](as: List[A]): List[(A,Int)] = {
    as.foldLeft(
        unit[Int,List[(A,Int)]](List()))(
        (acc,a) => for {
            xs <- acc
            n <- getState[Int]
            _ <- setState[Int](n + 1)
        } yield (a,n) :: xs
    ).runS(0).runT._1.reverse
}                                                 //> zip: [A](as: List[A])List[(A, Int)]
zip((1 to 10000).toList)                          //> res0: List[(Int, Int)] = List((1,0), (2,1), (3,2), (4,3), (5,4), (6,5), (7,

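This time it runs normally, with no StackOverflowError.

In fact we can treat Trampoline as a general-purpose solution to stack overflow.

First, we can use Trampoline's Monad nature to control function calls, as follows: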

val x = f()
val y = g(x)
h(y)
// The three function calls above can be written as:
for {
    x <- f()
    y <- g(x)
    z <- h(y)
} yield z

A concrete example:

implicit def step[A](a: => A): Trampoline[A] = {
    More(() => Done(a))
}                                                 //> step: [A](a: => A)ch13.ex1.Trampoline[A]
def getNum: Double = 3                            //> getNum: => Double
def addOne(x: Double) = x + 1                     //> addOne: (x: Double)Double
def timesTwo(x: Double) = x * 2                   //> timesTwo: (x: Double)Double
(for {
    x <- getNum
    y <- addOne(x)
    z <- timesTwo(y)
} yield z).runT                                   //> res6: Double = 8.0
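The for-comprehension here is only syntax for nested flatMap and map calls, which is where the flow control comes from. The sketch below spells out that desugaring. To stay self-contained it substitutes the standard library's `scala.util.control.TailCalls` for the worksheet's Trampoline and wraps each value explicitly with `done` instead of relying on the implicit `step`; the object name `DesugarSketch` is hypothetical.

```scala
import scala.util.control.TailCalls.{TailRec, done}

object DesugarSketch {
  def getNum: Double = 3
  def addOne(x: Double): Double = x + 1
  def timesTwo(x: Double): Double = x * 2

  // The for-comprehension form.
  val viaFor: TailRec[Double] = for {
    x <- done(getNum)
    y <- done(addOne(x))
    z <- done(timesTwo(y))
  } yield z

  // The same pipeline written as the nested flatMap/map calls it corresponds to.
  val viaFlatMap: TailRec[Double] =
    done(getNum).flatMap(x =>
      done(addOne(x)).flatMap(y =>
        done(timesTwo(y)).map(z => z)))
}
```

Both values produce the same number when `.result` is called on them.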

Or:

def fib(n: Int): Trampoline[Int] = {
    if (n <= 1) Done(n) else for {
        x <- More(() => fib(n-1))
        y <- More(() => fib(n-2))
    } yield x + y
}                                                 //> fib: (n: Int)ch13.ex1.Trampoline[Int]
(fib(10)).runT                                    //> res7: Int = 55
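The same recipe can be applied to the right fold from the beginning of the article, whose recursive call is not in tail position. What follows is a hedged sketch rather than the worksheet's Trampoline: it uses an invariant `Trampoline`, a single tail-recursive `run` function in place of resume/runT, and a curried `foldR`; the object name `FoldRSketch` is hypothetical.

```scala
import scala.annotation.tailrec

object FoldRSketch {
  sealed trait Trampoline[A] {
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = FlatMap(this, f)
    def map[B](f: A => B): Trampoline[B] = flatMap[B](a => Done(f(a)))
  }
  final case class Done[A](a: A) extends Trampoline[A]
  final case class More[A](k: () => Trampoline[A]) extends Trampoline[A]
  final case class FlatMap[A, B](sub: Trampoline[A], k: A => Trampoline[B]) extends Trampoline[B]

  @tailrec
  def run[A](t: Trampoline[A]): A = t match {
    case Done(a) => a
    case More(k) => run(k())
    case FlatMap(sub, f) => sub match {
      case Done(a)       => run(f(a))
      case More(k)       => run(k() flatMap f)
      case FlatMap(b, g) => run(b flatMap (x => g(x) flatMap f))
    }
  }

  // The recursive call is suspended in More; the pending f(h, _) is recorded by map.
  def foldR[A, B](as: List[A], b: B)(f: (A, B) => B): Trampoline[B] = as match {
    case Nil    => Done(b)
    case h :: t => More(() => foldR(t, b)(f)).map(r => f(h, r))
  }
}
```

With this version, `FoldRSketch.run(FoldRSketch.foldR((1 to 100000).toList, 0L)((a, acc) => a + acc))` completes, because the pending additions live on the heap as FlatMap nodes rather than as stack frames.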

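From the above we see that flatMap can be used for flow control over Trampoline computations. We can also interleave the steps of several Trampoline computations to emulate parallel execution: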

trait Trampoline[+A] {
    final def runT: A = resume match {
        case Right(a) => a
        case Left(k) => k().runT
    }
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = {
        this match {
//          case Done(a) => f(a)
//          case More(k) => f(runT)
            case FlatMap(a,g) => FlatMap(a, (x: Any) => g(x) flatMap f)
            case x => FlatMap(x, f)
        }
    }
    def map[B](f: A => B) = flatMap(a => Done(f(a)))
    def resume: Either[() => Trampoline[A], A] = this match {
        case Done(a) => Right(a)
        case More(k) => Left(k)
        case FlatMap(a,f) => a match {
            case Done(v) => f(v).resume
            case More(k) => Left(() => k() flatMap f)
            case FlatMap(b,g) => FlatMap(b, (x: Any) => g(x) flatMap f).resume
        }
    }
    def zip[B](tb: Trampoline[B]): Trampoline[(A,B)] = {
        (this.resume, tb.resume) match {
            case (Right(a),Right(b)) => Done((a,b))
            case (Left(f),Left(g)) => More(() => f() zip g())
            case (Right(a),Left(k)) => More(() => Done(a) zip k())
            case (Left(k),Right(a)) => More(() => k() zip Done(a))
        }
    }
}
case class Done[+A](a: A) extends Trampoline[A]
case class More[+A](k: () => Trampoline[A]) extends Trampoline[A]
case class FlatMap[A,B](sub: Trampoline[A], k: A => Trampoline[B]) extends Trampoline[B]

With this zip function we can combine several Trampoline computations so that their steps run interleaved:

def hello: Trampoline[Unit] = for {
    _ <- print("Hello ")
    _ <- println("World!")
} yield ()                                        //> hello: => ch13.ex1.Trampoline[Unit]

(hello zip hello zip hello).runT                  //> Hello Hello Hello World!
                                                  //| World!
                                                  //| World!
                                                  //| res8: ((Unit, Unit), Unit) = (((),()),())

Trampoline solves the big problem of StackOverflow. Now we can do functional programming with confidence.
