Having gone from Schedule to Execute, we have finally arrived at the BlockManager. This is the part carolz was curious about from the very beginning, but without finishing the previous two parts there was really no way to read this one. Of the big pieces of Spark (schedule, execute, blockmanager, network), this is the only one we haven't gone through carefully yet, which is a little exciting: the day we finish reading the source code feels one step closer. By the way, I recently read Intel's article on Spark GC tuning, and modifying the source still strikes me as pretty cool. We have made changes ourselves before, although that attempt can't really be called a successful one. A long road ahead~

Enough small talk. In the Spark Execute post we skipped over how the RDDs upstream of a ShuffledRDD write their data into buckets, and how the ShuffledRDD then reads that data back. That is exactly what this post takes a proper look at.

####1 Writing data####

Writing the data is part of a ShuffleMapTask: once the task has run the last RDD's compute to completion, it is time to write the data out.

We start from ShuffleMapTask's runTask, and in particular from shuffle.writers(bucketId).write(pair).

Here is the fragment of ShuffleMapTask.runTask that involves shuffle.writers:

val shuffleBlockManager = blockManager.shuffleBlockManager
var shuffle: ShuffleWriterGroup = null
shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
for (elem <- rdd.iterator(split, context)) {
  val pair = elem.asInstanceOf[Product2[Any, Any]]
  val bucketId = dep.partitioner.getPartition(pair._1)
  shuffle.writers(bucketId).write(pair)
}

First, the bucketId. Here dep.partitioner is the new HashPartitioner(numPartitions) that was passed in back when combineByKey was called.

================= About HashPartitioner =================

Partitioner.scala

/**
 * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
 * Java's `Object.hashCode`.
 *
 * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
 * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
 * produce an unexpected or incorrect result.
 */
class HashPartitioner(partitions: Int) extends Partitioner {
  def numPartitions = partitions

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }
}

Utils.scala

/* Calculates 'x' modulo 'mod', takes to consideration sign of x,
 * i.e. if 'x' is negative, than 'x' % 'mod' is negative too
 * so function return (x % mod) + mod in that case.
 */
def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}

In short, HashPartitioner.getPartition(key: Any) maps a key to a partition index, which amounts to:

val rawMod = key.hashCode % numPartitions
if (rawMod < 0) {
  rawMod + numPartitions
} else {
  rawMod
}
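
To make the sign handling concrete, here is a tiny standalone check (not part of the Spark source; the hash codes and partition count are made up):

// Standalone sketch: what getPartition effectively computes, assuming numPartitions = 3.
def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}

assert(nonNegativeMod(7, 3) == 1)   //  7 % 3 =  1
assert(nonNegativeMod(-7, 3) == 2)  // -7 % 3 = -1, so -1 + 3 = 2
assert(nonNegativeMod(-6, 3) == 0)  // -6 % 3 =  0, nothing to add (which is why the check is on rawMod)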

================= Done with HashPartitioner ===============

Now look at shuffle itself. From the assignment we know it is a ShuffleWriterGroup, and shuffle.writers is an Array[BlockObjectWriter]; in other words, bucketId picks one BlockObjectWriter out of that array, and the pair gets written to it.

val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
  fileGroup = getUnusedFileGroup()
  Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
    val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
    blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
  }
} else {
  // The if branch above can be ignored for now: consolidateShuffleFiles defaults to false.
  Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
    val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
    val blockFile = blockManager.diskBlockManager.getFile(blockId)
    // Because of previous failures, the shuffle file may already exist on this machine.
    // If so, remove it.
    if (blockFile.exists) {
      if (blockFile.delete()) {
        logInfo(s"Removed existing shuffle file $blockFile")
      } else {
        logWarning(s"Failed to remove existing shuffle file $blockFile")
      }
    }
    blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
  }
}

So shuffle.writers(bucketId) gives us a disk writer. That writer is built from (blockId, blockFile), and blockFile itself is derived from blockId. Where does this blockId come from, then?

blockId is built from (shuffleId, mapId, bucketId): shuffleId is dep.shuffleId, mapId is partitionId, and bucketId should just be the argument passed in (the tabulate index).

First, the format of a ShuffleBlockId:

BlockId.scala

@DeveloperApi
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
  extends BlockId {
  def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
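
As a quick sanity check, plugging made-up ids into the name formula above (the numbers are purely illustrative):

ShuffleBlockId(0, 3, 5).name  // "shuffle_0_3_5": shuffleId 0, mapId 3, reduceId 5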

Next, a look at how blockFile is obtained:

DiskBlockManager.scala

def getFile(filename: String): File = {
  // Figure out which local directory it hashes to, and which subdirectory in that
  val hash = Utils.nonNegativeHash(filename)
  val dirId = hash % localDirs.length
  val subDirId = (hash / localDirs.length) % subDirsPerLocalDir

  // Create the subdirectory if it doesn't already exist
  var subDir = subDirs(dirId)(subDirId)
  if (subDir == null) {
    subDir = subDirs(dirId).synchronized {
      val old = subDirs(dirId)(subDirId)
      if (old != null) {
        old
      } else {
        val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
        newDir.mkdir()
        subDirs(dirId)(subDirId) = newDir
        newDir
      }
    }
  }

  new File(subDir, filename)
}
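
Just to make the directory hashing concrete, here is a toy calculation with made-up numbers (suppose the filename hashes to 12345, there are 2 local dirs, and 64 sub-dirs per dir):

val hash = 12345               // pretend Utils.nonNegativeHash(filename) returned this
val dirId = hash % 2           // 12345 % 2 = 1  -> second local dir
val subDirId = (hash / 2) % 64 // 6172 % 64 = 28 -> subdirectory "1c" ("%02x".format(28))
// so this block file would end up at <localDirs(1)>/1c/shuffle_0_3_5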

With the file in hand, the next step is getting the disk writer.

BlockManager.scala

/**
 * A short circuited method to get a block writer that can write data directly to disk.
 * The Block will be appended to the File specified by filename. This is currently used for
 * writing shuffle files out. Callers should handle error cases.
 */
def getDiskWriter(
    blockId: BlockId,
    file: File,
    serializer: Serializer,
    bufferSize: Int): BlockObjectWriter = {
  val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
  val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
  new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites)
}

BlockObjectWriter.scala

def write(i: Int): Unit = callWithTiming(out.write(i))
override def write(b: Array[Byte]) = callWithTiming(out.write(b))
override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))

Here callWithTiming measures how long the write to the file takes and adds that duration to the write-time metric.
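
A rough sketch of what such a timing wrapper looks like (not copied from the Spark source; the accumulator name is an assumption):

// Sketch only: run the write and accumulate the elapsed nanoseconds.
private var timeWriting = 0L  // assumed metric field, later reported as the write time

private def callWithTiming(f: => Unit): Unit = {
  val start = System.nanoTime()
  f
  timeWriting += System.nanoTime() - start
}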

At this point, the file has been written.

To sum up, we have written several files. Each file's name has the format "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId, and hashing that name decides which subdirectory the file lands in. Here,

  • shuffleId is dep.shuffleId (incremented by 1 for every new shuffle)
  • mapId is partitionId
    • i.e. the p from when the ShuffleMapTask was created (p <- 0 until stage.numPartitions)
  • bucketId (the reduceId) is just the argument passed in
    • it depends on each pair._1's hashCode (essentially pair._1.hashCode % numPartitions): val bucketId = dep.partitioner.getPartition(pair._1)

There is also the question of the file subdirectories, but we'll set that aside for now; however the files were written is how they will be read back.

After all pairs have been written to their buckets, there are still a few more things to do:

  • Commit the writes. Get the size of each bucket block (total block size).
  • Update shuffle metrics.
    • new MapStatus(blockManager.blockManagerId, compressedSizes)
  • Release the writers back to the shuffle block manager.
  • Execute the callbacks on task completion.

With that, the task is done.

====================== A quick look at MapStatus ========================

The construction of the MapStatus here matters a lot: fetching the data later relies on it. Building it requires blockManager.blockManagerId, so where does that value come from?

When an executor is initialized, each executor runs

val _env = SparkEnv.create(conf, executorId, slaveHostname, 0, isDriver = false, isLocal = false)
SparkEnv.set(env)

In other words, every executor has its own SparkEnv. Each SparkEnv in turn creates a BlockManager, and each BlockManager carries a BlockManagerId(executorId, connectionManager.id.host, connectionManager.id.port, nettyPort).

BlockManagerId.scala

override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort)
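
Conceptually a MapStatus is just "which BlockManager wrote this map output, plus the compressed size of the block produced for each reducer". A rough sketch of that shape (the field names are my assumption, not checked against this exact Spark version):

// Sketch only: location of one map task's output, and the compressed size of the
// block written for each reduceId (used later to plan the fetches).
class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte])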

=========================== Done ==============================

####2 Reading data####

Next we follow the read path, starting from ShuffledRDD.

ShuffledRDD.scala

override def compute(split: Partition, context: TaskContext): Iterator[P] = {
  val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
  val ser = Serializer.getSerializer(serializer)
  SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
}

So that is the whole story here: the data is fetched by shuffledId and split.index.

======================== About shuffledId and split.index ====================

####1 ShuffledId####

Back in the spark scheduling post, when we discussed getMissingParentStages, there was a statement that iterates over rdd.dependencies. That call goes through each RDD's overridden getDependencies, so let's look at ShuffledRDD's getDependencies (this function is only called once):

override def getDependencies: Seq[Dependency[_]] = {
  List(new ShuffleDependency(prev, part, serializer))
}

Now Dependency.scala:

/**
 * :: DeveloperApi ::
 * Represents a dependency on the output of a shuffle stage.
 * @param rdd the parent RDD
 * @param partitioner partitioner used to partition the shuffle output
 * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null,
 *                   the default serializer, as specified by `spark.serializer` config option, will
 *                   be used.
 */
@DeveloperApi
class ShuffleDependency[K, V](
    @transient rdd: RDD[_ <: Product2[K, V]],
    val partitioner: Partitioner,
    val serializer: Serializer = null)
  extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {

  val shuffleId: Int = rdd.context.newShuffleId()

  rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}

Here rdd.context is the SparkContext, and rdd.context.newShuffleId() simply increments the previous shuffleId by 1.
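
For reference, that is just an atomic counter living in SparkContext; roughly (a sketch from memory, not a verbatim copy of the source):

import java.util.concurrent.atomic.AtomicInteger

// Each call hands out the next shuffle id: 0, 1, 2, ...
private val nextShuffleId = new AtomicInteger(0)
private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement()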

So we now know that the shuffleId is assigned while stages are being carved out, and that it increases by 1 each time.

####2 split.index####

The split here comes from the argument passed into iterator; tracing it upward, it was assigned when the ShuffleMapTask was created.

ShuffleMapTask.scala

var split = if (rdd == null) null else rdd.partitions(partitionId)

rdd.partitions(partitionId) yields a Partition, so where does that Partition's index come from?

As we said in the start from HadoopRDD post, the first time rdd.partitions is called it invokes the RDD's overridden getPartitions. Let's assume this is that first call (there is no further upstream to go), which gives us:

ShuffledRDD.scala

override def getPartitions: Array[Partition] = {
  Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
}

So ShuffledRDD actually has its own kind of Partition:

private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
  override val index = idx // this is the split.index we were after
  override def hashCode(): Int = idx
}

So split.index is just that i, running from 0 to part.numPartitions - 1.
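
Putting the two snippets together, a tiny standalone check of what split.index ends up being (this is a sketch, not Spark code; Partition is stubbed out):

// Minimal stand-ins for the classes quoted above.
trait Partition { def index: Int }
class ShuffledRDDPartition(val idx: Int) extends Partition {
  override val index = idx
  override def hashCode(): Int = idx
}

val numPartitions = 4 // stand-in for part.numPartitions
val partitions = Array.tabulate[Partition](numPartitions)(i => new ShuffledRDDPartition(i))
assert(partitions(3).index == 3) // split.index is simply the partition's position i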

========================== End of this aside ==========================

With that in place, we can go fetch the data.

SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)

Here SparkEnv.get.shuffleFetcher is in fact a BlockStoreShuffleFetcher.

BlockStoreShuffleFetcher.scala

private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {

  override def fetch[T](
      shuffleId: Int,
      reduceId: Int,
      context: TaskContext,
      serializer: Serializer)
    : Iterator[T] =
  {
    logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
    val blockManager = SparkEnv.get.blockManager

    val startTime = System.currentTimeMillis
    // Ask the master for the locations of the map output files
    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
      shuffleId, reduceId, System.currentTimeMillis - startTime))

    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
    for (((address, size), index) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
    }

    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
      case (address, splits) =>
        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
    }

    def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
      val blockId = blockPair._1
      val blockOption = blockPair._2
      blockOption match {
        case Some(block) => {
          block.asInstanceOf[Iterator[T]]
        }
        case None => {
          blockId match {
            case ShuffleBlockId(shufId, mapId, _) =>
              val address = statuses(mapId.toInt)._1
              throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
            case _ =>
              throw new SparkException(
                "Failed to get block " + blockId + ", which is not a shuffle block")
          }
        }
      }
    }

    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
    val itr = blockFetcherItr.flatMap(unpackBlock)

    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
      val shuffleMetrics = new ShuffleReadMetrics
      shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
      shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
      shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
      shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
      shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
      shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
      context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
    })

    new InterruptibleIterator[T](context, completionIter)
  }
}

The first step is getServerStatuses(shuffleId, reduceId), whose job, in the words of its doc comment, is:

Called from executors to get the server URIs and output sizes of the map outputs of a given shuffle.

It first checks, by shuffleId, whether the map output statuses are already cached locally; if not, they have to be fetched. It then checks whether someone else is already fetching them, and if so it waits. After the wait it runs fetchedStatuses = mapStatuses.get(shuffleId).orNull again, to cover the case where another thread completed a fetch between that get and the earlier fetching.synchronized block. In short, this is all about synchronization.
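
A self-contained toy of that synchronization pattern may help (this is a sketch of the idea only, not the real MapOutputTracker; the names cache, fetching, and remoteFetch are all illustrative):

import scala.collection.mutable

// Toy version of "only one thread fetches, the others wait and then re-read the cache".
object FetchOnce {
  private val cache = new mutable.HashMap[Int, String]   // stands in for mapStatuses
  private val fetching = new mutable.HashSet[Int]        // shuffleIds currently being fetched

  private def remoteFetch(shuffleId: Int): String = {
    Thread.sleep(100)                                     // pretend: askTracker(GetMapOutputStatuses(...))
    s"statuses-for-shuffle-$shuffleId"
  }

  def get(shuffleId: Int): String = {
    cache.synchronized { cache.get(shuffleId) } match {
      case Some(statuses) => statuses                     // already cached locally
      case None =>
        val weFetch = fetching.synchronized {
          if (fetching.contains(shuffleId)) {
            // somebody else is fetching this shuffleId: wait until they finish
            while (fetching.contains(shuffleId)) fetching.wait()
            false
          } else {
            fetching += shuffleId                         // we do the fetch ourselves
            true
          }
        }
        if (weFetch) {
          val statuses = remoteFetch(shuffleId)
          cache.synchronized { cache(shuffleId) = statuses }
          fetching.synchronized { fetching -= shuffleId; fetching.notifyAll() }
          statuses
        } else {
          // the re-read after waiting, like the second mapStatuses.get in the text
          cache.synchronized { cache(shuffleId) }
        }
    }
  }
}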

Below is the most relevant excerpt of getServerStatuses(shuffleId, reduceId):

val fetchedBytes =
  askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)

askTracker essentially boils down to

val future = trackerActor.ask(message)(timeout)
Await.result(future, timeout)

trackerActor is an ActorRef, an Akka class, and the message sent is a MapOutputTrackerMessage. In short, it asks the master for the output locations, gets the reply, and then returns MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses):

// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
private def convertMapStatuses(
    shuffleId: Int,
    reduceId: Int,
    statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
  assert (statuses != null)
  statuses.map {
    status =>
      if (status == null) {
        throw new FetchFailedException(null, shuffleId, -1, reduceId,
          new Exception("Missing an output location for shuffle " + shuffleId))
      } else {
        (status.location, decompressSize(status.compressedSizes(reduceId)))
      }
  }
}

Using Array's map method, every status in statuses is turned into a (status.location, size) pair.

So the location information comes back as an array of (address, size) pairs, i.e. there can be many addresses (which makes sense, since the shuffle data comes from many map outputs). BlockStoreShuffleFetcher numbers these entries with zipWithIndex (starting from 0), turning each into ((address, size), index), and then stores it as an (index, size) entry in splitsByAddress under its address. The address here is really a BlockManagerId. In short, we now have the addresses and can go fetch the data from the remote side; the toy run below first shows the shape of this regrouping.
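
A toy run of the regrouping above, with plain strings standing in for BlockManagerId and ShuffleBlockId; all values are made up:

import scala.collection.mutable.{ArrayBuffer, HashMap}

// Pretend getServerStatuses returned three map outputs living on two executors "A" and "B".
val shuffleId = 0
val reduceId = 5
val statuses = Array(("A", 100L), ("B", 200L), ("A", 300L)) // (address, size) per mapId

val splitsByAddress = new HashMap[String, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
  splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
// splitsByAddress: "A" -> [(0,100), (2,300)], "B" -> [(1,200)]   (index plays the role of mapId)

val blocksByAddress = splitsByAddress.toSeq.map { case (address, splits) =>
  (address, splits.map(s => (s"shuffle_${shuffleId}_${s._1}_$reduceId", s._2)))
}
// "A" -> [("shuffle_0_0_5",100), ("shuffle_0_2_5",300)], "B" -> [("shuffle_0_1_5",200)]

With blocksByAddress in this shape, the actual fetch follows: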

val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)

First, blockManager.getMultiple:

/**
 * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns
 * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined
 * fashion as they're received. Expects a size in bytes to be provided for each block fetched,
 * so that we can control the maxMegabytesInFlight for the fetch.
 */
def getMultiple(
    blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
    serializer: Serializer): BlockFetcherIterator = {
  val iter =
    if (conf.getBoolean("spark.shuffle.use.netty", false)) {
      new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
    } else {
      new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
    }
  iter.initialize()
  iter
}

The return value is a BasicBlockFetcherIterator that has already been initialize()-ed (Netty is off by default).

The initialization does the following:

  • Split local and remote blocks
  • Add the remote requests into our queue in a random order
  • Send out initial requests for blocks, up to our maxBytesInFlight
  • Get Local Blocks
    • Get the local blocks while remote blocks are being fetched. Note that it’s okay to do these all at once because they will just memory-map some files, so they won’t consume any memory that might exceed our maxBytesInFlight

That said, we couldn't really tell from the code how the remote fetches run concurrently with the local ones ╮(╯▽╰)╭

OK, then blockFetcherItr.flatMap(unpackBlock) is executed.

And that's where I'll stop reading for today; to be continued tomorrow.