Let's trace through the Spark source a little and see how task scheduling actually works. Annoyingly, most of the write-ups online are based on a different Spark version from the one carolz is reading, which is a bit sad. One more thing: the Spark source is written in Scala, and for newcomers all that syntactic sugar is genuinely painful >.<, so hang in there a bit. Alright, let's begin.

Spark Version: 1.0.1

Let's start with an example:

val textFile = sc.textFile("readme.md")
textFile.filter(line=>line.contains("spark")).count()  

This program counts how many lines in readme.md contain the word "spark". It is a simple job with no shuffle and no reduce.

As covered in the earlier post on RDDs, textFile here gives us a MappedRDD. We then call filter on it, a basic RDD operation that we can find in RDD.scala.

def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f))

================================= A closure interlude ===============================

You can see this creates a FilteredRDD, passing line => line.contains("spark") in as the argument. Let's first chase sc.clean(f); quietly open SparkContext.scala:

private[spark] def clean[F <: AnyRef](f: F): F = {
    ClosureCleaner.clean(f)
    f
}

Then head to org.apache.spark.util.ClosureCleaner and find ClosureCleaner.clean. The function is a bit long, so I won't paste it; let's first work out what it is doing. As the name says, it cleans closures, and to understand closure cleaning you first have to understand what a closure is.

From that article we can learn: a closure holds references to variables of its outer function/class, so even after the outer function's stack frame has finished, the closure still holds references to those stack/heap variables and can change their values. Anonymous functions are a good way to get a feel for what a closure really is. So "a closure has the ability to reference outer variables, and that ability carries potential risks: it affects GC of those variables, and it affects serialization of the function object."
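
A tiny sketch (plain Scala, my own toy example, not Spark code) of a closure capturing an outer variable; this captured reference is exactly what makes GC and serialization tricky:

object ClosureDemo {
  def makeCounter(): () => Int = {
    var count = 0                 // local to makeCounter's frame
    () => { count += 1; count }   // the closure keeps a reference to count
  }

  def main(args: Array[String]): Unit = {
    val next = makeCounter()      // makeCounter has already returned, yet...
    println(next())               // 1 -- count is still alive and mutable
    println(next())               // 2
  }
}

Because next still references count, that variable cannot be collected, and if next had to be serialized (as Spark task closures do), everything it references would have to come along and be serializable too.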

Reading on through the code, note the comment TODO: cache outerClasses / innerClasses / accessedFields. Combined with the section above, the intent is clear: clean has to work out the closure's outerClasses / innerClasses / accessedFields, i.e. exactly which outer state the closure can reach, so that this captured state can be handled properly with respect to GC and serialization.

================================= That's it for closures ===============================

Back to the main thread: after the FilteredRDD is created, we call count() on it. This is another basic RDD operation, also in RDD.scala.

def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum

As we know, RDD operations fall into two kinds, transformations and actions. count() is an action, and an action triggers real execution of the RDD chain built so far, which here means calling sc.runJob. OK, open SparkContext.scala and follow sc.runJob down.
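
As a quick sanity check on that split, here's a sketch (assuming a live SparkContext named sc, purely for illustration): the filter line returns immediately without touching the file, and only count actually kicks off a job.

val lines    = sc.textFile("readme.md")            // transformation: nothing runs yet
val filtered = lines.filter(_.contains("spark"))   // transformation: still nothing runs
println(filtered.toDebugString)                    // just prints the lineage, no job
val n = filtered.count()                           // action: this is what reaches sc.runJob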

sc.runJob has several different overloads in SparkContext; let's find the one called above (there is quite a chain of nested overload calls here; if you'd rather not read them all, skip straight to the last runJob).

/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
    runJob(rdd, func, 0 until rdd.partitions.size, false)
}

This in turn calls the following overload:

/**
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
*/
def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: Iterator[T] => U,
    partitions: Seq[Int],
    allowLocal: Boolean
): Array[U] = {
    runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
}

Which then calls this one:

/**
* Run a function on a given set of partitions in an RDD and return the results as an array. The
* allowLocal flag specifies whether the scheduler can run the computation on the driver rather
* than shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    allowLocal: Boolean
): Array[U] = {
    val results = new Array[U](partitions.size)
    runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
    results
}

And finally we land in this overload:

/**
* Run a function on a given set of partitions in an RDD and pass the results to the given
* handler function. This is the main entry point for all actions in Spark. The allowLocal
* flag specifies whether the scheduler can run the computation on the driver rather than
* shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    allowLocal: Boolean,
    resultHandler: (Int, U) => Unit) {
    if (dagScheduler == null) {
        throw new SparkException("SparkContext has been shutdown")
    }
    val callSite = getCallSite // this just records the call-site chain; don't worry about it
    val cleanedFunc = clean(func)
    logInfo("Starting job: " + callSite)
    val start = System.nanoTime
    dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
        resultHandler, localProperties.get) // look here, look here
    logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
    rdd.doCheckpoint()
}

Notice that the very first thing this function does is check whether dagScheduler is null and throw if it is. Let's leave a placeholder here: dagScheduler is in fact initialized when the SparkContext is created (that does seem to be the case); it would be worth studying later exactly what happens when the master starts up.

OK, the key line in the code above is dagScheduler.runJob. Let's chase it into DAGScheduler.scala.

def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: String,
    allowLocal: Boolean,
    resultHandler: (Int, U) => Unit,
    properties: Properties = null)
{
    //submit a job, then wait for it to complete
    val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
    waiter.awaitResult() match {
        case JobSucceeded => {}
        case JobFailed(exception: Exception) =>
            logInfo("Failed to run " + callSite)
            throw exception
    }
  }

So it simply submits a job and waits for it to finish. Let's chase submitJob.

/**
* Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
* can be used to block until the the job finishes executing or can be used to cancel the job.
*/
def submitJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: String,
    allowLocal: Boolean,
    resultHandler: (Int, U) => Unit,
    properties: Properties = null): JobWaiter[U] =
{
    // Check to make sure we are not launching a task on a partition that does not exist.
    //partitions was passed in by runJob as 0 until rdd.partitions.size, so every index exists
    val maxPartitions = rdd.partitions.length
    partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
        throw new IllegalArgumentException(
            "Attempting to access a non-existent partition: " + p + ". " +
            "Total number of partitions: " + maxPartitions)
    }
    //each submitted job increments jobId by 1
    val jobId = nextJobId.getAndIncrement()
    if (partitions.size == 0) {
        return new JobWaiter[U](this, jobId, 0, resultHandler)
    }

    assert(partitions.size > 0)
    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
    val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
    eventProcessActor ! JobSubmitted(
        jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
    waiter
}

One question here: when partitions.size == 0, why simply return new JobWaiter[U](this, jobId, 0, resultHandler)? The resultHandler passed in earlier was (index, res) => results(index) = res, which has nothing to do with the user's func at all.
For the program we are tracing this case never arises, because partitions was built as 0 until rdd.partitions.size, so there is always at least one partition; let's shelve the question for now.

Let's look at the partitions.size != 0 path. First func gets a little processing; let's see what that does.

asInstanceOf is Scala's cast mechanism; here it turns func's T and U into existential _ types.
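
For readers new to Scala, a tiny sketch of asInstanceOf (my own example, not Spark code): it's an unchecked cast that trusts you at compile time and only fails at runtime if you were wrong.

val x: Any = "hello"
val s = x.asInstanceOf[String]      // fine: x really is a String
// x.asInstanceOf[Int]              // would throw ClassCastException at runtime

// The same trick on a function value, analogous to the func2 cast above:
val f: (Int, Iterator[String]) => Int = (_, it) => it.length
val f2 = f.asInstanceOf[(Int, Iterator[_]) => _]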

Then a JobWaiter is created. A comment in JobWaiter explains it well:

/**
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
*/

Then the waiter, together with the earlier arguments, is wrapped up in a JobSubmitted message and sent to eventProcessActor.
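
For readers who haven't met Akka: ! is the fire-and-forget send on an ActorRef, which is how DAGScheduler hands the event to its event loop. A minimal sketch of the same pattern (my own toy actor and message, not the Spark ones; Akka 2.2/2.3-era API, which is what Spark 1.x ships with):

import akka.actor.{Actor, ActorSystem, Props}

case class JobSubmittedDemo(jobId: Int, callSite: String)   // stand-in for Spark's JobSubmitted

class EventProcessDemoActor extends Actor {
  def receive = {
    case JobSubmittedDemo(jobId, callSite) =>
      println("handling job " + jobId + " submitted from " + callSite)
  }
}

object AkkaSendDemo extends App {
  val system = ActorSystem("demo")
  val eventProcessActor = system.actorOf(Props[EventProcessDemoActor], "eventProcess")
  eventProcessActor ! JobSubmittedDemo(0, "count at <console>")   // asynchronous: returns immediately
  Thread.sleep(500)
  system.shutdown()
}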

It turns out eventProcessActor also lives in DAGScheduler.scala, inside the DAGSchedulerEventProcessActor class, which has a receive function. Let's see how it handles a JobSubmitted message.

/**
* The main event loop of the DAG scheduler.
*/
def receive = {
    case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
        dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
            listener, properties)
    ... // many more cases, omitted for now
}

Ah, so it calls dagScheduler.handleJobSubmitted. A quick aside: JYY once taught us that a function should generally fit on one screen, otherwise it's hard to debug; this one takes up roughly a full screen on my 22.5-inch monitor, so it's on the long side, but the logic isn't that complicated.

private[scheduler] def handleJobSubmitted(jobId: Int,
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    allowLocal: Boolean,
    callSite: String,
    listener: JobListener,
    properties: Properties = null)
{
    var finalStage: Stage = null
    try {
        // New stage creation may throw an exception if, for example, jobs are run on a
        // HadoopRDD whose underlying HDFS files have been deleted.
        finalStage = newStage(finalRDD, partitions.size, None, jobId, Some(callSite))
    } catch {
        case e: Exception =>
            logWarning("Creating new stage failed due to exception - job: " + jobId, e)
            listener.jobFailed(e)
            return
    }
    if (finalStage != null) {
        val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
        clearCacheLocs()
        logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format(
            job.jobId, callSite, partitions.length, allowLocal))
        logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")")
        logInfo("Parents of final stage: " + finalStage.parents)
        logInfo("Missing parents: " + getMissingParentStages(finalStage))
        if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
            // Compute very short actions like first() or take() with no parent stages locally.
            //for small enough jobs, just compute locally and be done
            listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties))
            runLocally(job)
        } else {
            //the main path is here
            jobIdToActiveJob(jobId) = job //this is a HashMap
            activeJobs += job
            resultStageToJob(finalStage) = job //this is a HashMap
            listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray,properties))
            submitStage(finalStage)
        }
    }
    submitWaitingStages()
}

You can see that it starts right away from finalStage. For reading this part I leaned on one of the articles listed in the references.

First a new stage is created starting from finalRDD, which in our example is the FilteredRDD. From that stage an ActiveJob is created; there isn't much to say about that data structure, it's only a few lines.

Then comes a check: if the job is small enough and has no parent stages, it can be computed locally (straight to runLocally(job)); otherwise we need submitStage(finalStage).

Whether the job runs locally or has its stages submitted, note that both paths call listenerBus.post. listenerBus is a LiveListenerBus instance, passed in when the DAGScheduler was created; think of it as an ArrayList holding every listener interested in Spark status events. Here a SparkListenerJobStart is posted; it's part of the DeveloperApi, so let's not dwell on it for now.
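
As a side note, user code can hook into that same bus. A sketch assuming the 1.0-era SparkListener API (the listener below and its printlns are mine, not from Spark):

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerJobEnd}

// A listener that just logs the events DAGScheduler posts above.
class JobLoggingListener extends SparkListener {
  override def onJobStart(jobStart: SparkListenerJobStart): Unit =
    println("job " + jobStart.jobId + " started")
  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit =
    println("job " + jobEnd.jobId + " finished: " + jobEnd.jobResult)
}

sc.addSparkListener(new JobLoggingListener)   // registers it on the LiveListenerBus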

================================= A local-execution interlude ===============================

First, runLocally. It spins up a thread whose job is to call runLocallyWithinThread(job); once started, the thread simply runs. Runs what, exactly? Let's take a look.

// Broken out for easier testing in DAGSchedulerSuite.
protected def runLocallyWithinThread(job: ActiveJob) {
    var jobResult: JobResult = JobSucceeded
    try {
        SparkEnv.set(env)
        val rdd = job.finalStage.rdd
        val split = rdd.partitions(job.partitions(0))
        val taskContext =
             new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
        try {
            val result = job.func(taskContext, rdd.iterator(split, taskContext))
            job.listener.taskSucceeded(0, result)
        } finally {
            taskContext.executeOnCompleteCallbacks()
        }
    } catch {
        case e: Exception =>
            val exception = new SparkDriverExecutionException(e)
            jobResult = JobFailed(exception)
            job.listener.jobFailed(exception)
    } finally {
        val s = job.finalStage
        stageIdToJobIds -= s.id    // clean up data structures that were populated for a local job,
        stageIdToStage -= s.id     // but that won't get cleaned up via the normal paths through
        stageToInfos -= s          // completion events or stage abort
        jobIdToStageIds -= job.jobId
        listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
    }
}

It starts out by declaring the job a JobSucceeded, and changes that later if an exception is caught. Once the computation finishes, the result is delivered via job.listener.taskSucceeded(0, result).

================================= That's it for local execution ===============================

Now for jobs that go remote. Compared with the local path, the remote path adds a series of HashMap bookkeeping steps plus a call to submitStage(finalStage). Let's start from that submit, still in DAGScheduler.scala.

/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
    val jobId = activeJobForStage(stage)
    if (jobId.isDefined) {
        logDebug("submitStage(" + stage + ")")
        if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
            //this is where we look for missing parent stages
            val missing = getMissingParentStages(stage).sortBy(_.id)
            logDebug("missing: " + missing)
            if (missing == Nil) {
                logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
                submitMissingTasks(stage, jobId.get)
                runningStages += stage
            } else {
                for (parent <- missing) {
                    submitStage(parent) //recursion! recursion! recursion!!!
                }
                waitingStages += stage
            }
        }
    } else {
        abortStage(stage, "No active job for stage " + stage.id)
    }
}

================================= Finding the missing stages ===============================

First, let's see how getMissingParentStages(stage) works.

private def getMissingParentStages(stage: Stage): List[Stage] = {
    val missing = new HashSet[Stage] //holds the missing stages
    val visited = new HashSet[RDD[_]]
    def visit(rdd: RDD[_]) {
        if (!visited(rdd)) {
            visited += rdd
            if (getCacheLocs(rdd).contains(Nil)) {
                for (dep <- rdd.dependencies) {
                    dep match {
                        case shufDep: ShuffleDependency[_,_] =>
                            val mapStage = getShuffleMapStage(shufDep, stage.jobId)
                            if (!mapStage.isAvailable) {
                                missing += mapStage
                            }
                        case narrowDep: NarrowDependency[_] =>
                            visit(narrowDep.rdd)
                    }
                }
            }
        }
    }
    visit(stage.rdd)
    missing.toList
}

In short, it walks up the dependency chain, collects every missing stage into the missing list, and returns it. Note the distinction between wide and narrow dependencies here: for a narrow dependency we simply recurse into the parent RDD, and no new stage is created; only a wide dependency, the kind that needs a shuffle, gets a new Stage. For a wide dependency, getShuffleMapStage is called and the result goes onto the missing list. Inside getShuffleMapStage we reach newOrUsedStage, a function meant specifically for shuffle stages. So what does it do?

 /**
 * Create a shuffle map Stage for the given RDD.  The stage will also be associated with the
 * provided jobId.  If a stage for the shuffleId existed previously so that the shuffleId is
 * present in the MapOutputTracker, then the number and location of available outputs are
 * recovered from the MapOutputTracker
 */
private def newOrUsedStage(
    rdd: RDD[_],
    numTasks: Int,
    shuffleDep: ShuffleDependency[_,_],
    jobId: Int,
    callSite: Option[String] = None)
  : Stage =
{
  //note that a shuffle does create a new Stage
  val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite)
  if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
    val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
    val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
    for (i <- 0 until locs.size) {
      stage.outputLocs(i) = Option(locs(i)).toList   // locs(i) will be null if missing
    }
    stage.numAvailableOutputs = locs.count(_ != null)
  } else {
    // Kind of ugly: need to register RDDs with the cache and map output tracker here
    // since we can't do it in the RDD constructor because # of partitions is unknown
    logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
    mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size)
  }
  stage
}

Leaving a placeholder here; I don't feel like digging into this right now.

================================= Missing stages all found ===============================

Back to finding all the missing stages. If the missing list != Nil, we haven't reached the top of the lineage yet and the stages aren't fully carved out, so submitStage(parent) recurses upward. Once missing == Nil, every parent dependency has been found and the stages have been split along the shuffle boundaries, and at that point we can call submitMissingTasks.
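
To make the stage split concrete, here's a sketch of a job that does contain a shuffle (not from the source; just the classic word count): the ShuffleDependency introduced by reduceByKey is exactly where the DAG gets cut, giving one shuffle map stage plus the final result stage, whereas our filter/count example stays a single stage.

import org.apache.spark.SparkContext._   // brings reduceByKey onto (K, V) RDDs in 1.x (the shell imports this for you)

// Stage 1 (shuffle map stage): textFile -> flatMap -> map, ends at the shuffle
// Stage 2 (final stage):       reduceByKey's output -> collect
val counts = sc.textFile("readme.md")
  .flatMap(_.split(" "))
  .map(word => (word, 1))
  .reduceByKey(_ + _)      // wide (shuffle) dependency: stage boundary here
  .collect()

Now, submitMissingTasks itself: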

/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
    logDebug("submitMissingTasks(" + stage + ")")
    // Get our pending tasks and remember them in our pendingTasks entry
    val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
    myPending.clear()
    var tasks = ArrayBuffer[Task[_]]()
    if (stage.isShuffleMap) {
        for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
            //this decides which workers this stage's partition prefers to run on
            val locs = getPreferredLocs(stage.rdd, p)
            tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
        }
    } else {
        // This is a final stage; figure out its job's missing partitions
        //a stage with no shuffle is the final stage
        val job = resultStageToJob(stage)
        for (id <- 0 until job.numPartitions if !job.finished(id)) {
            val partition = job.partitions(id)
            val locs = getPreferredLocs(stage.rdd, partition)
            tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
        }
    }
    //what is this doing? if there is no active job, does it go to the 'default' pool?
    val properties = if (jobIdToActiveJob.contains(jobId)) {
        jobIdToActiveJob(stage.jobId).properties
    } else {
        // this stage will be assigned to "default" pool
        null
    }

    // must be run listener before possible NotSerializableException
    // should be "StageSubmitted" first and then "JobEnded"
    listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))

    //this stage still has unfinished tasks
    if (tasks.size > 0) {
        // Preemptively serialize a task to make sure it can be serialized. We are catching this
        // exception here because it would be fairly hard to catch the non-serializable exception
        // down the road, where we have several different implementations for local scheduler and
        // cluster schedulers.
        try {
            SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
        } catch {
            case e: NotSerializableException =>
            abortStage(stage, "Task not serializable: " + e.toString)
            runningStages -= stage
            return
        }

        logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
        myPending ++= tasks
        logDebug("New pending tasks: " + myPending)
        //look here, look here
        taskScheduler.submitTasks(
          new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
        stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
    } else {
        //all tasks of this stage are already done
        logDebug("Stage " + stage + " is actually done; %b %d %d".format(
        stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
        runningStages -= stage
    }
}

This function first checks whether the stage is a shuffle stage: if it is, a ShuffleMapTask is created for each partition; otherwise a ResultTask is created for each partition. There is a nice little design detail here: each partition gets to pick its preferred machines. Let's look at that first and then come back to the main thread.

=============================== Each partition gets to pick its preferred machines ===============================

For every partition of every stage there is a getPreferredLocs call that works out which machines that partition prefers. If the partition is cached, pick the machines caching it. If not, see whether the RDD itself has placement preferences. Failing that, see whether the RDD has narrow dependencies, and take the preferred locations of the first partition of the first narrow dependency that has any. Ideally the choice would account for transfer sizes, but for now it doesn't; I don't know whether later versions improve on this.

/**
 * Synchronized method that might be called from other threads.
 * @param rdd whose partitions are to be looked at
 * @param partition to lookup locality information for
 * @return list of machines that are preferred by the partition
 */
private[spark]
def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized {
  // If the partition is cached, return the cache locations
  val cached = getCacheLocs(rdd)(partition)
  if (!cached.isEmpty) {
    return cached
  }
  // If the RDD has some placement preferences (as is the case for input RDDs), get those
  val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
  if (!rddPrefs.isEmpty) {
    return rddPrefs.map(host => TaskLocation(host))
  }
  // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
  // that has any placement preferences. Ideally we would choose based on transfer sizes,
  // but this will do for now.
  rdd.dependencies.foreach {
    case n: NarrowDependency[_] =>
      for (inPart <- n.getParents(partition)) {
        val locs = getPreferredLocs(n.rdd, inPart)
        if (locs != Nil) {
          return locs
        }
      }
    case _ =>
  }
  Nil
}

============================================ End of location picking ============================================

OK, back to submitMissingTasks. As we just saw, for a shuffle stage and for the final stage, a different kind of Task is created per partition and appended to the tasks ArrayBuffer, which tells us that each task corresponds to one partition of one stage. The next step is the one I annotated in the code and haven't fully figured out, but a rough guess: if the job is already active (i.e. one of its stages has run), the tasks go into whatever pool that job was assigned to. Then we drop something onto listenerBus again, this time a SparkListenerStageSubmitted.
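
About that "default" pool: the properties come from the job's local properties on the submitting thread. A sketch of how a user would set them (assuming the 1.x fair-scheduler settings; our example never does this):

import org.apache.spark.{SparkConf, SparkContext}

// Enable the fair scheduler and route this thread's jobs into a named pool;
// these local properties are what jobIdToActiveJob(jobId).properties carries.
val conf = new SparkConf().setAppName("pools").set("spark.scheduler.mode", "FAIR")
val sc = new SparkContext(conf)
sc.setLocalProperty("spark.scheduler.pool", "production")
sc.textFile("readme.md").count()                   // scheduled in the "production" pool
sc.setLocalProperty("spark.scheduler.pool", null)  // back to the default pool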

As mentioned, the stage is split into tasks by partition. At this point, tasks.size > 0 means there are still unfinished tasks; otherwise the stage is removed from runningStages. We focus on the tasks.size > 0 case.

Next the task has to be serialized, otherwise it can't be submitted, and then comes taskScheduler.submitTasks(new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)). OK, we are finally done with the stage level and have arrived at tasks. Open TaskScheduler.scala, and it turns out the whole implementation lives in TaskSchedulerImpl.scala. Below is that class's comment.

==================================== An aside on what TaskScheduler does ==============================

This is the class comment on TaskSchedulerImpl:

Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
It can also work with a local setup by using a LocalBackend and setting isLocal to true.
It handles common logic, like determining a scheduling order across jobs, waking up to launch
speculative tasks, etc.

Clients should first call initialize() and start(), then submit task sets through the
runTasks method.

THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
threads, so it needs locks in public API methods to maintain its state. In addition, some
SchedulerBackends synchronize on themselves when they want to send events here, and then
acquire a lock on us, so we need to make sure that we don't try to lock the backend while
we are holding a lock on ourselves.

//////====================================== About that Backend =================================///////

A sudden urge to digress on when the backend gets wired into TaskSchedulerImpl. The code is in SparkContext.scala: createTaskScheduler, whose purpose is to "create a task scheduler based on a given master URL. Extracted for testing." My guess is it runs when the driver/master side starts up, i.e. when the SparkContext is constructed. It branches on how the cluster is run: local, Mesos, YARN and so on; what we normally use is Spark's own standalone mode, which matches SPARK_REGEX. In that branch we have:

val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
scheduler  

So the backend in use is a SparkDeploySchedulerBackend (in the scheduler package), described as "an endpoint for executors to talk to us"; a reasonable guess is that it's the endpoint workers use to report completed tasks back to the master.
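
For orientation, the branch taken in createTaskScheduler is driven entirely by the master URL the application was configured with. A sketch (host and port are made up):

import org.apache.spark.{SparkConf, SparkContext}

// "spark://..." matches SPARK_REGEX -> standalone -> SparkDeploySchedulerBackend;
// "local[4]" would instead go down the LocalBackend branch.
val conf = new SparkConf()
  .setAppName("backend-demo")
  .setMaster("spark://master-host:7077")
val sc = new SparkContext(conf)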

//////===================================== End of the Backend aside =====================================//////

In short: all of this gets created the moment the master is started!

===================================== End of the TaskScheduler aside =================================

Now for the submitTasks we just called:

override def submitTasks(taskSet: TaskSet) {
  val tasks = taskSet.tasks
  logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
  this.synchronized { 
    val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
    activeTaskSets(taskSet.id) = manager
    schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

    //isLocal defaults to false
    if (!isLocal && !hasReceivedTask) {
      starvationTimer.scheduleAtFixedRate(new TimerTask() {
        override def run() {
          if (!hasLaunchedTask) {
            logWarning("Initial job has not accepted any resources; " +
              "check your cluster UI to ensure that workers are registered " +
              "and have sufficient memory")
          } else {
            this.cancel()
          }
        }
      }, STARVATION_TIMEOUT, STARVATION_TIMEOUT)
    }
    hasReceivedTask = true
  }
  backend.reviveOffers()
}

Note the synchronized: we can guess this section is shared by several threads, and only one of them may run the block at a time. First a TaskSetManager is created and mapped to the TaskSet's id, then handed to schedulableBuilder. Then comes the if: isLocal defaults to false, so nothing more to say about that. hasReceivedTask is marked @volatile, so every time a thread uses the variable it reads the latest written value.
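
A toy sketch of why @volatile matters here (my own example, not Spark code): the flag is written by the thread that submits tasks and read by the starvation timer's thread, and @volatile guarantees the timer sees the update instead of a stale false.

object VolatileFlagDemo {
  @volatile private var hasLaunchedTask = false   // same idea as TaskSchedulerImpl's flag

  def main(args: Array[String]): Unit = {
    val starvationTimer = new java.util.Timer()
    starvationTimer.scheduleAtFixedRate(new java.util.TimerTask {
      override def run(): Unit =
        if (!hasLaunchedTask) println("Initial job has not accepted any resources...")
        else { println("tasks launched, stop nagging"); cancel() }
    }, 0L, 500L)

    Thread.sleep(1500)
    hasLaunchedTask = true     // this write is visible to the timer thread thanks to @volatile
    Thread.sleep(1000)
    starvationTimer.cancel()
  }
}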

Then we pick things up from backend.reviveOffers(). This method is in CoarseGrainedSchedulerBackend; it sends a message (with !) to driverActor, and driverActor then calls makeOffers().

// Make fake resource offers on all executors
def makeOffers() {
  launchTasks(scheduler.resourceOffers(
    executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
}

Note that this is where tasks start being launched onto executors.

// Launch tasks returned by a set of resource offers
def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
  for (task <- tasks.flatten) {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val serializedTask = ser.serialize(task)
    if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
      val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
      scheduler.activeTaskSets.get(taskSetId).foreach { taskSet => //what's stored in here are TaskSetManagers
        try { 
          var msg = "Serialized task %s:%d was %d bytes which " +
            "exceeds spark.akka.frameSize (%d bytes). " +
            "Consider using broadcast variables for large values."
          msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize)
          taskSet.abort(msg)
        } catch {
          case e: Exception => logError("Exception in error callback", e)
        }
      }
    }
    else {
      freeCores(task.executorId) -= scheduler.CPUS_PER_TASK
      //look here, look here
      executorActor(task.executorId) ! LaunchTask(new SerializableBuffer(serializedTask))
    }
  }
}

First, every task gets serialized (it has to be before it can be sent anywhere). If something goes wrong (the size-related if above), a message is built for the corresponding TaskSetManager and taskSet.abort(msg) is called. If everything is fine, the number of CPU cores this task needs is first subtracted from freeCores (so the master keeps a global view of the cluster's cores and hands them out; exactly how they are allocated deserves a closer look some other time), and then comes executorActor(task.executorId) ! LaunchTask(new SerializableBuffer(serializedTask)). executorActor is a HashMap[String, ActorRef] in CoarseGrainedSchedulerBackend, so executorActor(task.executorId) really gives us an ActorRef, an Akka construct. So where does the message go? The answer is executor.CoarseGrainedExecutorBackend.scala, a class started on the worker (see the articles in the references for details). In effect, the serialized Task is shipped to the worker over Akka. And just like that, we have crossed from the scheduler package into the executor package.
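
On that size check: when a serialized task gets anywhere near the Akka frame size, the usual 1.x remedies are raising spark.akka.frameSize (the value is in MB) or, better, broadcasting large values instead of capturing them in the task closure, which is exactly what the abort message suggests. A sketch (numbers and names are made up):

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("frame-size-demo")
  .set("spark.akka.frameSize", "64")       // MB; the limit checked in launchTasks above
val sc = new SparkContext(conf)

val bigTable = (1 to 1000000).map(i => ("key" + i, i)).toMap
val lookup = sc.broadcast(bigTable)        // shipped once per executor, not once per task
sc.textFile("readme.md")
  .filter(line => lookup.value.contains(line.trim))
  .count()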

What should a worker do when it receives a task? Let's look at CoarseGrainedExecutorBackend.scala.

override def receive = {
  ...
  case LaunchTask(data) =>
    if (executor == null) {
      logError("Received LaunchTask command but executor was null")
      System.exit(1)
    } else {
      val ser = SparkEnv.get.closureSerializer.newInstance()
      val taskDesc = ser.deserialize[TaskDescription](data.value)
      logInfo("Got assigned task " + taskDesc.taskId)
      executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) //look here, look here
    }
  ...
}

This handles the LaunchTask message, whose parameter is the serialized task we just sent over. If nothing goes wrong, the task is first deserialized, and then executor.launchTask is called to start it. The executor lives in Executor.scala; calling launchTask creates a TaskRunner.

def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
  val tr = new TaskRunner(context, taskId, serializedTask)
  runningTasks.put(taskId, tr)
  threadPool.execute(tr)
}

So a TaskRunner is created here; it's an inner class of Executor that extends Runnable. It is stored in runningTasks (a ConcurrentHashMap holding the tasks currently running) and then handed to threadPool for execution. Let's look at its run():

override def run(){
    ...
    val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
    updateDependencies(taskFiles, taskJars)
    task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
    ...
    val value = task.run(taskId.toInt)
    //after this the result gets serialized and sent back, etc.; let's set that aside for now
}

First the Task is deserialized to get the jar files and other files it depends on, and updateDependencies(taskFiles, taskJars) pulls those dependencies in. If everything is in order, val value = task.run(taskId.toInt) runs the task. Worth mentioning: the JVM's GC time is measured around this.
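
On that GC measurement: the executor samples the JVM's accumulated GC time before and after running the task and records the difference in the task metrics. A sketch of the same measurement with plain JMX (the "task" below is a stand-in, not Spark code):

import java.lang.management.ManagementFactory
import scala.collection.JavaConverters._

// Total time (ms) all collectors in this JVM have spent so far.
def totalGcTimeMs(): Long =
  ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum

val gcBefore = totalGcTimeMs()
val result = (1 to 2000000).map(_.toString).length   // stand-in for task.run(taskId.toInt)
val gcDuringTask = totalGcTimeMs() - gcBefore
println("GC time while the 'task' ran: " + gcDuringTask + " ms (result: " + result + ")")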

That's the entire journey of Spark task scheduling, right up to the point where the worker starts executing. Next time we'll look at how the worker actually runs these tasks.

References
(just some blog posts carolz stumbled on, but they're quite good. Papers don't walk through the source in detail, so I won't list them; there are only a handful anyway, and everyone knows them):

ColZer’s Github
Spark Source Code Study - Hanwei feeds dog
Spark internal - 多样化的运行模式(上)