Spark广播变量与累加器 2015-04-24 17:00

广播变量

说明

Spark中分布式执行的代码需要传递到各个Executor的Task上运行。对于一些只读、固定的数据(比如从DB中读出的数据),每次都需要Driver广播到各个Task上,这样效率低下。广播变量允许将变量只广播(提前广播)给各个Executor。该Executor上的各个Task再从所在节点的BlockManager获取变量,而不是从Driver获取变量,从而提升了效率。

一个Executor只需要在第一个Task启动时,获得一份Broadcast数据,之后的Task都从本节点的BlockManager中获取相关数据。

样例代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import org.apache.spark.{SparkConf, SparkContext}

/**
  * Usage: BroadcastTest [slices] [numElem] [broadcastAlgo] [blockSize]
  */
object BroadcastTest {
  def main(args: Array[String]) {

    val bcName = if (args.length > 2) args(2) else "Http"
    val blockSize = if (args.length > 3) args(3) else "4096"

    val sparkConf = new SparkConf().setAppName("Broadcast Test")
      .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroadcastFactory")
      .set("spark.broadcast.blockSize", blockSize)
    val sc = new SparkContext(sparkConf)

    val slices = if (args.length > 0) args(0).toInt else 2
    val num = if (args.length > 1) args(1).toInt else 1000000

    val arr1 = (0 until num).toArray

    for (i <- 0 until 3) {
      println("Iteration " + i)
      println("===========")
      val startTime = System.nanoTime
      val barr1 = sc.broadcast(arr1)
      val observedSizes = sc.parallelize(1 to 10, slices).map(_ => barr1.value.size)
      // Collect the small RDD so we can print the observed sizes locally.
      observedSizes.collect().foreach(i => println(i))
      println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
    }

    sc.stop()
  }
}

注意事项

第27行要使用barr1.value而不是arr1来访问广播变量,否则无效。

相关配置

  • spark.broadcast.compress

是否压缩广播变量,默认为True。

  • spark.broadcast.factory

org.apache.spark.broadcast.HttpBroadcastFactory或org.apache.spark.broadcast.TorrentBroadcastFactory,默认为Torrent。

  • spark.broadcast.blockSize

类型为Torrent时的块大小,以KB为单位。如果太大,会导致并行度过小。如果太小,会影响BlockManager的性能。

累加器

说明

通过SparkContext.accumulator(v)来创建accumulator类型的变量,然后运行的task可以使用“+=”操作符来进行累加。但是task不能读取到该变量,只有driver program能够读取(通过.value),这也是为了避免使用太多读写锁吧。

示例代码

1
2
3
val accum = sc.accumulator(0, "My Accumulator")
sc.parallelize(Array(1, 2, 3, 4)).foreach(x => accum += x)
println(accum.value)

自定义累加器类型

累加器类型除Spark自带的int、float、Double外,也支持开发人员自定义。方法是继承AccumulatorParam[Vector]。

Tags: #Spark    Post on Spark