Tensorflow XLA Service Buffer Optimization Explained

(Figure omitted: the TensorFlow architecture stack and where XLA sits within it.)

Through layer-by-layer conversion (see Tensorflow XLA Service Explained I), the Graph has already been expressed as an HloModule by the time it enters the XLA Service. As the core of the graph compiler, the XLA Service is responsible for compiling the computation graph expressed by the HloModule into programs that can execute directly on the various hardware platforms (Backends), and the heart of the compilation is optimizing the code, both device-independent and device-dependent:

  1. Optimize the computation graph represented by the HloModule and lower it to LLVM IR
  2. Use LLVM to generate hardware-specific binaries

As a general-purpose compiler framework, LLVM performs extensive optimization on the LLVM IR and generates an efficient target binary, so XLA developers mainly focus on the stage-1 optimizations: the device-independent graph optimization algorithms. Compiler is the interface class through which XLA adapts to hardware; every backend adapting to XLA must implement its methods, above all RunBackend(). As described in Tensorflow XLA Service Explained I, this interface is the entry point for graph optimization and compilation and drives the whole optimize-and-compile process.

In addition, the XLA Service implements a set of generic optimizations (various scheduling policies, various memory-optimization algorithms) for the hardware backends' compilers to use, mainly from within RunBackend(). Take the compilation and optimization pipeline of XLA's GPU backend as an example:

NVPTXCompiler::RunBackend()
  hlo_schedule = GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)
  BufferAssigner::Run(hlo_schedule->ConsumeHloOrdering()...)
  entry_computation->Accept(&ir_emitter)
  CompileToPtx() 
  CompilePtxOrGetCachedResult()

-2- Select a schedule policy suitable for the GPU from the generic XLA Service layer.
-3- Based on the schedule, perform device-independent buffer optimization, mainly aimed at making the buffers as small as possible. Note that this optimization is device-independent and cannot exploit the hardware's memory characteristics.
-4- Lower the HloModule to LLVM IR.
-5,6- Use the LLVM framework to compile the LLVM IR into binary code.

This article focuses on -3-, the heart of XLA's optimization.

Let's break BufferAssigner::Run() down a level further.

NVPTXCompiler::RunBackend()
  hlo_schedule = GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)
  //this analysis figures out which temp buffers are required to run the computation
  BufferAssigner::Run(hlo_schedule->ConsumeHloOrdering()...)
    assigner.CreateAssignment(HloModule, hlo_ordering, buffer_size)
      liveness = BufferLiveness::Run()
      assignment = new BufferAssignment(module, liveness, ...)
      set<LogicalBuffer*> colocated_buffers
      set<BufferAllocation::Index> colocated_allocations
      vector<ColocatedBufferSet> colocated_buffer_sets
      BuildColocatedBufferSets(&colocated_buffer_sets)
      colorer_(assignment->liveness())
      AssignColocatedBufferSets(colocated_buffer_sets, assignment, &colocated_buffers, &colocated_allocations);
      GatherComputationsByAllocationType(module, &thread_local_computations, &global_computations)
      for computation : global_computations:
        AssignBuffersForComputation(computation, false, buffers_to_assign_sequentially)
      AssignBuffersWithSequentialOrdering(buffers_to_assign_sequentially, assignment)
      for computation : thread_local_computations:
        AssignBuffersForComputation()
      for buffer : assignment->liveness().maybe_live_out_buffers():
        if assignment->HasAllocation(buffer):
          assignment->GetMutableAssignedAllocation(buffer).set_maybe_live_out(true)
      assignment->CombineTempAllocations()
      return std::move(assignment)
  entry_computation->Accept(&ir_emitter)
  CompileToPtx() 
  CompilePtxOrGetCachedResult()

-6- Run the BufferLiveness analysis: analyze the interference relationships among the LogicalBuffers of the whole HloModule, providing the basis for the subsequent optimizations.
-11- BuildColocatedBufferSets: based on the BufferLiveness analysis, partition all LogicalBuffers into several buffer sets and perform an initial merge of these sets. As the source comment puts it, each colocated buffer set is a group of LogicalBuffers that can share one BufferAllocation, and sharing an Allocation means sharing the same physical memory (GPU device memory).
-12- colorer_ defaults to BufferLiveness::DefaultColorer(), which sets the color of every LogicalBuffer instance to 0.
-13- AssignColocatedBufferSets: allocate a BufferAllocation for each buffer set, associating every LogicalBuffer set with one. buffer_size_ is used here: it computes a LogicalBuffer's size, which must equal the size of the corresponding Allocation; see tf2xla/while_op.cc, tf2xla/if_op.cc and the kConditional code in xla/client/builder.cc, which visibly require the Shapes of the various bodies to match. The TEST cases confirm this as well.
-14- GatherComputationsByAllocationType: classify Allocations as global or thread-local according to the properties of the LogicalBuffers they contain. This split is key to understanding the memory optimization; details below.
-16- AssignBuffersForComputation: associate Allocations with computations. This call site only handles the global computations; temp buffers are collected into buffers_to_assign_sequentially and handled later.

BufferLiveness::Run()

So BufferLiveness is about obtaining the lifetime relationships among buffers in order to decide the buffer-reuse strategy. The overall call stack is shown below; this stage itself divides into three passes: LogicalBufferAnalysis -> TuplePointsToAnalysis -> BufferLiveness.

BufferAssigner::Run()
  assigner::CreateAssignment()
    liveness = BufferLiveness::Run(module, std::move(hlo_ordering)  //class BufferLiveness
      liveness = new BufferLiveness()
      liveness->Analyze()
        points_to_analysis_ = TuplePointsToAnalysis::Run()
          logical_buffer_analysis = LogicalBufferAnalysis::Run()
            analysis = new LogicalBufferAnalysis()
            analysis.Analyze()
            return analysis
          analysis = new TuplePointsToAnalysis(logical_buffer_analysis)
          analysis.Analyze()
          return analysis
        maybe_live_out_buffers_ = points_to_analysis_->GetPointsToSet(root).CreateFlattenedSet()
      return liveness
    assignment(new BufferAssignment(module, std::move(liveness)

-1- Entry point of the memory optimization.
-3,16- Obtain the BufferLiveness; compare with -16-: the BufferLiveness obtained here goes on to back the buffer-optimization decisions, an elegant OOP way of expressing the pipeline.
-8- Obtain the LogicalBufferAnalysis.
-9- Construct a set of LogicalBuffers according to each HloInstruction's Shape. A Shape is itself a tree structure describing an HloInstruction's output; one LogicalBuffer corresponds to one node (Subshape) of a Shape and represents one logical buffer, indexed by a pair<HloInstruction*, ShapeIndex>.
-10- The raw material of the liveness analysis, the LogicalBuffers, is now constructed and stored; return the analysis and move on to the next stage, TuplePointsToAnalysis.
-11- Construct a TuplePointsToAnalysis instance from the LogicalBufferAnalysis instance holding all the LogicalBuffers. 'Tuple' is the form used to describe the tree structure of buffers: E = {%1, %2, {%3, %4}} denotes a tree of depth 3: one node at depth 1, the root; three nodes at depth 2, %1, %2 and an unnamed node we'll call "X"; and two nodes at depth 3, %3 and %4, the children of "X". 'PointsTo' means just that: in this example the root's PointsTo is %1, %2 and "X", while "X"'s PointsTo is %3 and %4. TuplePointsToAnalysis therefore analyzes the dependency relationships among all LogicalBuffers in the computation graph and stores them in PointsToSets; a detailed analysis follows below.
-12- Run the analysis logic.
-13- The dependency analysis of the graph's LogicalBuffers is complete and stored; return the analysis and move on to the next stage, BufferLiveness.
-14- Use the TuplePointsToAnalysis instance to fetch the root's aliased buffers, i.e. the buffers that may carry the HloModule's outputs, and store them in maybe_live_out_buffers_.
-15- Return the liveness instance. The TuplePointsToAnalysis instance stores the LogicalBuffer dependencies, but BufferLiveness does not store a per-LogicalBuffer 'liveness'; instead it wraps TuplePointsToAnalysis with a set of predicates over specific LogicalBuffers.
-16- Pass the BufferLiveness instance into the BufferAssignment instance that builds the LogicalBuffer-to-BufferAllocation mapping.
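
To make the LogicalBuffer identity concrete, here is a minimal, self-contained C++ sketch. The types are hypothetical stand-ins (the real LogicalBuffer also carries size, color, and other metadata), assumed only for illustration:

// Illustrative model only: a LogicalBuffer is identified by the pair
// (instruction, ShapeIndex), one buffer per node of the output shape tree.
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

using ShapeIndex = std::vector<int64_t>;  // path into the output shape tree

struct LogicalBuffer {
  int64_t id;               // unique within the module-wide analysis
  std::string instruction;  // stands in for HloInstruction*
  ShapeIndex index;         // which subshape of the output this buffer backs
};

int main() {
  // %tuple = tuple(%a, %b): %a and %b each define their own buffer at {},
  // while kTuple defines only its top-level tuple table, also at {}.
  std::map<std::pair<std::string, ShapeIndex>, LogicalBuffer> output_buffers;
  output_buffers[{"%a", {}}] = {0, "%a", {}};
  output_buffers[{"%b", {}}] = {1, "%b", {}};
  output_buffers[{"%tuple", {}}] = {2, "%tuple", {}};
  return 0;
}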

That is the backbone of the whole BufferLiveness analysis; some key details follow.

LogicalBufferAnalysis

class LogicalBufferAnalysis : public DfsHloVisitorWithDefault

LogicalBufferAnalysis "is a" DfsHloVisitorWithDefault, the interface the HLO layer provides for traversing HloModule/HloComputation/HloInstruction in visitor style (yes, the Visitor pattern from design patterns). It is the basic tool XLA uses for HLO optimization, and for anyone extending or debugging XLA it is a handy tool for understanding the code. For a DfsHloVisitorWithDefault we mainly care about the various "Handle" implementations, plus whatever is specific to the subclass, here LogicalBufferAnalysis.

//tensorflow/compiler/xla/service/logical_buffer_analysis.h
Status DefaultAction(HloInstruction* hlo_instruction) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleDomain(HloInstruction* domain) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;

-2- The default Handle: walk every Subshape of the HloInstruction and call NewLogicalBuffer() to construct a LogicalBuffer for each. Every LogicalBuffer has an id that is unique within the LogicalBufferAnalysis, and since a compilation graph has only one LogicalBufferAnalysis instance, the id is unique within the HloModule. The constructed LogicalBuffers are stored in both 'logical_buffers_' and 'output_buffers_'; the main difference between the two is the indexing scheme: logical_buffers_ stores the raw LogicalBuffer instances as unique_ptrs, while output_buffers_ is a hash map keyed by std::pair<HloInstruction*, ShapeIndex>, allowing fast LogicalBuffer lookup. If you look at the implementation of ShapeIndex, its index is not a single "id" but an "absl::InlinedVector<int64, 2> indices_", in which each int64 identifies one node of the multiway tree that is the HloInstruction's Shape. ShapeIndex is only responsible for the indexing; what the "leaf" is, is up to the user. Here, LogicalBufferAnalysis's leaf is the LogicalBuffer, and the hash map ties ShapeIndex and LogicalBuffer together.

-3:11- Clearly, the nine "non-default" Handles exist because DefaultAction's treatment does not suit the HloOpcodes they are responsible for. The comments in the source are quite clear, at least once you understand TuplePointsToAnalysis. LogicalBufferAnalysis and TuplePointsToAnalysis live in different files mainly because of the visitor structure; viewed from a single HloInstruction or a single buffer, the two are one continuous design and usually need to be read together. For understanding the code, the following Handles deserve special attention:
-3- HandleTuple: NewLogicalBuffer(tuple, /*index=*/{}); allocates a LogicalBuffer indexed by "{}". "{}" does not look like it fits the meaning of "ShapeIndex" at all, carrying no shape information. For now it is enough to know that the index is "{}" because kTuple only needs a "top-level buffer"; what a "top-level buffer" is, I analyze in detail under TuplePointsToAnalysis.
-5:6- HandleBitcast and HandleDomain construct no LogicalBuffer at all.
-7- HandleCopy simply constructs one new LogicalBuffer, again a "top-level buffer".
-8- HandleRecvDone: two LogicalBuffers; the top-level buffer at {} stores the received data in buffer-alias form, and the buffer at {1} stores the handshake token.
-9- HandleSend: three LogicalBuffers; the top-level buffer at {} stores the data about to be sent in buffer-alias form, the buffer at {1} stores the handshake context, and the buffer at {2} stores the handshake token.
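
Before moving on, here is a small, self-contained sketch of the ShapeIndex idea: a vector of int64 forms a path that addresses one node (Subshape) of the shape tree. ShapeNode and GetSubshape are hypothetical names for illustration, loosely mirroring what ShapeUtil::GetSubshape does over real Shapes:

#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

struct ShapeNode {  // one node of an instruction's output shape tree
  std::vector<std::unique_ptr<ShapeNode>> children;  // empty => array leaf
};

// Follow the index path from the root, one child per path element.
const ShapeNode* GetSubshape(const ShapeNode* root,
                             const std::vector<int64_t>& index) {
  const ShapeNode* node = root;
  for (int64_t i : index) node = node->children.at(i).get();
  return node;
}

int main() {
  // Shape {f32[2], {f32[3], f32[4]}}: a tuple whose second element is a tuple.
  ShapeNode root;
  root.children.push_back(std::make_unique<ShapeNode>());    // index {0}
  auto inner = std::make_unique<ShapeNode>();
  inner->children.push_back(std::make_unique<ShapeNode>());  // index {1,0}
  inner->children.push_back(std::make_unique<ShapeNode>());  // index {1,1}
  root.children.push_back(std::move(inner));                 // index {1}

  assert(GetSubshape(&root, {1, 0}) != nullptr);  // addresses f32[3]
  return 0;
}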

Having covered the visitor part, let's look at the specific part.

Status LogicalBufferAnalysis::Analyze();
  // We filter out fusion computations, and get to them through fusion
  // instructions. This is because it's possible to have orphaned (unreachable)
  // fusion computations, and we don't want to try to assign buffers to those.
  std::vector<HloInstruction*> fusion_instructions;
  for (auto* computation : module_->MakeNonfusionComputations()):
    computation->Accept(this);
    for (auto* instruction : computation->instructions()):
      if (instruction->opcode() != HloOpcode::kFusion):
        continue;
      GatherFusionInstructions(instruction, &fusion_instructions);
  for (auto* instruction : fusion_instructions):
    instruction->fused_expression_root()->Accept(this);

The function above is the entry point of the traversal, and it filters out kFusion computations. The comment explains it well: "prune" the computation graph and do not assign LogicalBuffers to unreachable fusion computations, to avoid wasting memory; this kind of handling appears in a lot of XLA code. Here is a classic question: the TF runtime already removed unreachable nodes during its earlier graph processing, so why does XLA do it again? The key is the Op->HloInstruction mapping: XLA composes hundreds of Op types out of a small number of HloOpcodes, so even though TF removed unreachable ops earlier, once the Op graph is converted into an HloInstruction graph, new unreachable nodes can appear; the nodes here are no longer Ops but the finer-grained HloInstructions. Granularity aside, many of the graph optimizations run by the classic TF engine and by XLA are similar, and the two never abstracted these algorithms at a higher level, so the functionality is implemented twice. Some argue that from an OOP perspective there is room to refactor here; my take is exactly the opposite: rather than introducing coupling through further abstraction, XLA would be best off decoupling itself from the TF engine entirely and plugging into the TF core as an external plugin. Feel free to leave me a comment if you want to discuss.

TuplePointsToAnalysis

The main difficulty in understanding this analysis is understanding what "TuplePointsTo" means; once that is clear, its internal design is self-evident. For an HloInstruction, "PointsTo" is the content being expressed: the buffers that HloInstruction depends on; "Tuple" is the form of expression: the topological relationships among buffers written as "{..., {...}}".

In the source, the following comment explains it well:

  // For the subshape at the given index (where index is defined as in
  // ShapeUtil::GetSubshape) this method returns the set of HLO instructions
  // which may produce the tuple subshape at that index. For example, given:
  //
  // %tuple1 = tuple(...)
  // %tuple2 = tuple(...)
  // %select = select(%tuple1, %tuple2)
  // %nested_tuple = tuple(%select, %tuple1)
  //
  // These are the values for tuple_sources() for the PointsToSet of
  // %nested_tuple:
  //
  // tuple_sources({}) = {%nested_tuple}
  // tuple_sources({0}) = {%tuple1, %tuple2}
  // tuple_sources({1}) = {%tuple1}
  //
  // tuple_sources() at the index of an array shape (not a tuple) returns the
  // empty set. The instructions in the set returned by tuple_sources
  // necessarily are either Tuple instructions, constants, or parameters.

Rendered as a diagram (the figures are omitted here), the blue arrows in the left-hand diagram carried the meaning of "PointsTo": each arrow denoted a buffer index, and all the arrows together with the buffers they point at form the kTuple's "PointsToSet". Personally I don't think kTupleSelect is a good choice for this example: it is not handled by the DefaultAction used in most cases but, like kTuple, merely aliases its input buffers; that is why I drew the simpler right-hand diagram, which I hope made the point clearer.
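
As a usage fragment (not standalone code), the comment's example can be queried roughly as follows. The calls mirror the TuplePointsToAnalysis/PointsToSet interface referenced in this post; treat the exact signatures as assumptions:

// Query the PointsToSet of %nested_tuple from the example above.
const PointsToSet& pts = points_to_analysis->GetPointsToSet(nested_tuple);

// Buffers that may occupy subshape {0} of %nested_tuple's output: the
// top-level buffers of %tuple1 and %tuple2, since %select merely aliases
// one of its operands.
const PointsToSet::BufferList& at0 = pts.element(/*index=*/{0});

// Instructions that may have produced the tuple at {0}:
// {%tuple1, %tuple2}, matching tuple_sources({0}) in the comment.
const PointsToSet::SourceSet& sources = pts.tuple_sources(/*index=*/{0});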

Having covered what "TuplePointsTo" means, let's analyze the implementation details from the source.

// A class describing the source(s) of the Buffer(s) contained in the output of
// a particular HLO instruction. The structure of PointsToSet mirrors the
// structure of the instruction's shape, which may be an arbitrary tree (eg, a
// nested tuple). Each node in this tree corresponds to a single buffer in the
// instruction's output and contains the set of Buffers which might define
// the corresponding buffer.
class PointsToSet {
 private:
  struct Elem {
    BufferList buffers;
    SourceSet tuple_sources;
  };
  ShapeTree<Elem> tree_;
};
// DFS visitor that performs tuple points-to analysis. This analysis determines
// the potential sources of each buffer in each instruction's output.
class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
  // Information kept per instruction
  struct PerInstruction {
    std::unique_ptr<PointsToSet> points_to_set;
    // Empirically, ~92% of instructions have 1
    // instruction_defined_buffer, and 99% have 0 or 1
    BufferDefinitionVector instruction_defined_buffers;
  };

  // The logical buffers for this module.
  const std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis_;

  // A map from instruction->unique_id() to PerInstruction
  absl::flat_hash_map<int, std::unique_ptr<PerInstruction>> per_instruction_;

  // A map from LogicalBuffer->id() to alias information about that logical
  // buffer
  std::vector<BufferAliasVector> logical_buffer_aliases_;

TuplePointsToAnalysis is likewise a DfsHloVisitorWithDefault, and for any visitor we mainly care about its Handle implementations. Attentive readers will notice that the Handle names here are identical to those implemented by LogicalBufferAnalysis; in fact each TuplePointsToAnalysis Handle builds on the corresponding LogicalBufferAnalysis Handle, so whenever a TuplePointsToAnalysis Handle is unclear, read the same-named Handle in LogicalBufferAnalysis first.

Status DefaultAction(HloInstruction* hlo_instruction) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleDomain(HloInstruction* domain) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleAddDependency(HloInstruction* add_dependency) override;

-1- DefaultAction: LogicalBufferAnalysis's DefaultAction() constructs LogicalBuffer instances from the Shape; the DefaultAction() here takes that result and stores each LogicalBuffer into the PointsToSet's buffers; in addition, for an HloInstruction with a tuple Shape, it records the HloInstruction itself as the "source" of those buffers. Which XlaOps set their Shape to a tuple? Infeed, InfeedWithToken, RecvFromHost, RecvWithToken, SendToHost, SendWithToken: all XlaOps used for bulk data transfer. Constructing the PointsToSet instances is TuplePointsToAnalysis's main job as a DfsHloVisitorWithDefault.
-2- On top of LogicalBufferAnalysis's HandleTuple(), construct the PointsToSet instance. Note that 1. the index used inside the Handle to fetch the LogicalBuffer matches the index under which LogicalBufferAnalysis stored it, and 2. a Tuple's PointsToSet is the union of its operands' PointsToSets, including both the buffers and those buffers' sources. BufferLiveness relies on this later.
-3:10- On top of the corresponding LogicalBufferAnalysis Handles, construct the corresponding PointsToSet instances.

With the shared part analyzed, let's look at the specific part of TuplePointsToAnalysis. As mentioned above, the Analyze() method is the entry point.

TuplePointsToAnalysis::Analyze()
  std::vector<HloInstruction*> fusion_instructions
  for computation : module_->MakeNonfusionComputations()):
    computation->Accept(this)
    PopulateDefinedBuffersAndAliases(computation->instructions())
      for instruction : instructions:
        PerInstruction* pi = PerInst(instruction)
        GatherBuffersDefinedByInstruction(instruction, &pi->instruction_defined_buffers))
      points_to_set = GetPointsToSet(instruction)
      points_to_set.ForEachElement(
        [this, &instruction](const ShapeIndex& index,const PointsToSet::BufferList& pointed_to_buffers) {
          for (const LogicalBuffer* buffer : pointed_to_buffers):
            logical_buffer_aliases_[buffer->id()].emplace_back(instruction, index)
        })
    for instruction : computation->instructions():
      if instruction->opcode() == HloOpcode::kFusion:
        GatherFusionInstructions(instruction, &fusion_instructions)
  // Run points-to analysis on fusion instructions in 'computation'.
  for instruction : fusion_instructions:
    instruction->fused_expression_root()->Accept(this))
    PopulateDefinedBuffersAndAliases(instruction->fused_instructions())

-3- As in LogicalBufferAnalysis, the same "pruning" is applied here.
-5- Besides running the Handles via Accept(), TuplePointsToAnalysis post-processes the Handles' results with PopulateDefinedBuffersAndAliases(), producing two stores: 'TuplePointsToAnalysis::PerInstruction::instruction_defined_buffers' and 'TuplePointsToAnalysis::logical_buffer_aliases_'. instruction_defined_buffers is built by GatherBuffersDefinedByInstruction(), which again iterates over all Subshapes. After everything so far, can't we already obtain an HloInstruction's LogicalBuffers easily; why iterate once more? If you have that question, look at HandleTuple() and the other "custom" Handles of TuplePointsToAnalysis: for kTuple and a few other HloOpcodes, the sources of their PointsTo are not the instruction itself but the sources of their operands' PointsTo. Such HloInstructions have no "defined buffers" of their own, and since the precondition for the later allocation assignment is that a "buffer is defined by instruction", this step picks out exactly the to-be-assigned "buffers defined by instruction". The implementation also performs a sanity check along the way: for kTuple-like HloInstructions, none of their buffers' sources may be the instruction itself. As for logical_buffer_aliases_, think of it as a reverse index, which is what "alias" suggests: it lets you go from a LogicalBuffer->id quickly back to the corresponding HloInstruction and ShapeIndex.
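
A minimal sketch of the reverse-index idea: from a LogicalBuffer id back to every (instruction, ShapeIndex) position where that buffer may appear. The types are illustrative stand-ins; the real BufferAlias stores an HloInstruction* rather than a string:

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

using BufferAlias = std::pair<std::string /*instruction*/,
                              std::vector<int64_t> /*ShapeIndex*/>;

int main() {
  // Indexed by LogicalBuffer::id(). Buffer 0 is defined by %param at {},
  // but also appears at index {1} of %tuple's output, because kTuple
  // aliases its operands instead of defining new buffers.
  std::vector<std::vector<BufferAlias>> logical_buffer_aliases(1);
  logical_buffer_aliases[0].push_back({"%param", {}});
  logical_buffer_aliases[0].push_back({"%tuple", {1}});
  return 0;
}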

BufferLiveness

After LogicalBufferAnalysis has built the raw material (the LogicalBuffers) and TuplePointsToAnalysis has built the relationships among them (the PointsToSets), most of the hard work is done; BufferLiveness only needs a thin wrapper to expose interfaces for judging a LogicalBuffer's liveness, chiefly the following:

  // Returns true if the live range of the buffer containing the output of 'a'
  // may overlap with the live range of the buffer of 'b'. If instruction 'a'
  // interferes with instruction 'b' then they cannot share the same buffer.
  bool MayInterfere(const LogicalBuffer& a, const LogicalBuffer& b) const;

  // Returns true if the buffer for the given instruction may be live out of the
  // module. That is, the instruction's buffer may be included in the output of
  // the entry computation.
  bool MaybeLiveOut(const LogicalBuffer& buffer) const;

  // Returns the complete set of buffers that may be live out of the module.
  const PointsToSet::BufferSet& maybe_live_out_buffers() const;
  // Returns the underlying points-to analysis used for this liveness analysis.
  const TuplePointsToAnalysis& points_to_analysis() const;

Here again, a few small details are worth attention.

Status BufferLiveness::Analyze() {
  points_to_analysis_ = TuplePointsToAnalysis::Run(module_)
  for computation : module_->computations():
    if computation->IsFusionComputation():
      continue;
    // Gather all instructions whose buffers might alias other instructions into
    // the set aliased_buffers_.  This includes those contained as a tuple
    // element in other instruction's output.
    for instruction : computation->instructions():
      for LogicalBuffer* aliased_buffer : points_to_analysis_->GetPointsToSet(instruction).CreateFlattenedSet():
        if (aliased_buffer->instruction() != instruction):
          aliased_buffers_.insert(aliased_buffer);
    if computation == module_->entry_computation():
      HloInstruction* root = computation->root_instruction();
      maybe_live_out_buffers_ = points_to_analysis_->GetPointsToSet(root).CreateFlattenedSet();

-11- If a LogicalBuffer in an HloInstruction's PointsToSet does not have that HloInstruction as its source, the HloInstruction is kTuple-like: it owns no actual LogicalBuffer and is no LogicalBuffer's source; such buffers are recorded in aliased_buffers_.
-15- A live-out buffer is one whose lifetime exceeds the entry_computation (or, if you prefer, the HloModule): the buffer that carries the entire entry_computation's output. Since root is the overall entry and exit of the program, it is easy to accept that the LogicalBuffers in its PointsToSet are the program's total output buffers; but why "maybe"? The comment in the code says that later execution-flow optimization may adjust the hlo_ordering, in which case root may no longer be the last HloInstruction. From my understanding of XLA I cannot think of an optimization that would hoist root earlier; perhaps the author was just leaving themselves an out, or, more likely, I am missing some detail (TODO).

The two methods below are the core of BufferLiveness and are called frequently by the later optimizations; their logic is simple, so I won't expand on them.

  // Returns true if the live range of the buffer containing the output of 'a'
  // may overlap with the live range of the buffer of 'b'. If instruction 'a'
  // interferes with instruction 'b' then they cannot share the same buffer.
  bool MayInterfere(const LogicalBuffer& a, const LogicalBuffer& b) const;
  // Returns true if the buffer for the given instruction may be live out of the
  // module. That is, the instruction's buffer may be included in the output of
  // the entry computation.
  bool MaybeLiveOut(const LogicalBuffer& buffer) const;
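
For intuition, the live-range test at the heart of MayInterfere() can be sketched as below. This is a simplification that assumes a single sequential order; the real implementation also consults the hlo_ordering and buffer aliasing:

#include <cstdint>

struct LiveRange { int64_t start; int64_t end; };  // positions in the order

// Two buffers may interfere unless one dies strictly before the other is
// defined; non-interfering buffers are candidates for sharing memory.
bool MayInterfere(const LiveRange& a, const LiveRange& b) {
  bool a_before_b = a.end < b.start;
  bool b_before_a = b.end < a.start;
  return !(a_before_b || b_before_a);
}

int main() {
  LiveRange t0{0, 3}, t1{5, 9};
  return MayInterfere(t0, t1) ? 1 : 0;  // disjoint ranges: no interference
}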

BuildColocatedBufferSets()

The core of this method is to pick out the LogicalBuffers that can share a memory chunk and store them in the same BufferSet.

// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
// in the same allocation (currently just supports kWhile, kCall, and
// kConditional and input output aliasing).
void BufferAssigner::BuildColocatedBufferSets(

This comment reveals several things:
1. Buffers in the same set share one allocation, i.e. one block of physical memory.
2. Currently only kWhile, kCall, and kConditional (plus input/output aliasing) are supported; what these three HloOpcodes have in common is that they all have called_computations, the "hinges" connecting the entry_computation to the other computations.
3. The function merges only within a computation: no ColocatedBufferSet will cross computations. This shows up in many places.

BuildColocatedBufferSets()
  module->input_output_alias_config().ForEachAlias({...})
  for computation : module->MakeComputationPostOrder():
    if computation->IsFusionComputation():
      continue
    for instruction : computation->MakeInstructionPostOrder():
      if kWhile:
        ...
      else if kCall:
        ...
      else if kConditional:
        ...
    
  MergeColocatedBufferSets()   
    is_readonly_entry_parameter = [](){}
    set_can_be_merged                     // this is a vector<bool>
    for i : colocated_buffer_sets.size():
      for buffer : colocated_buffer_sets[i]:
        if buffer_liveness.MaybeLiveOut(buffer) || is_readonly_entry_parameter(buffer) || buffer->instruction()->opcode()==HloOpcode::kConstant:
          set_can_be_merged[i] = false
          break
      
    cannot_merge_buffer_sets = [](){};
    interference_map
    for i : colocated_buffer_sets.size():
      for j : colocated_buffer_sets.size():
        if cannot_merge_buffer_sets(i, j):
          interference_map[i].push_back(j)
          interference_map[j].push_back(i)
    assigned_colors = ColorInterferenceGraph(interference_map)
      c_sort(nodes, by_descending_conflict_count)   // most-conflicted first
      assigned_colors
      for node : nodes:
        used = colors already taken by interference_map[node]
        color = smallest color not in used
        assigned_colors[node] = color
      return assigned_colors
    for set_id, color : assigned_colors:
      new_colocated_buffer_sets[color].merge(colocated_buffer_sets[set_id])

-2- Handle the aliased buffers, indexed by ShapeIndex, which guarantees that LogicalBuffers in the same set have the same Shape and therefore the same buffer size.
-7,9,11- Build the colocated_sets for the kWhile, kCall, and kConditional buffers under the rule "buffers with the same ShapeIndex go into the same buffer set".
-14- MergeColocatedBufferSets: everything before this point builds the buffer sets independently of one another; this step tries to merge several buffer sets further.
-23- cannot_merge_buffer_sets decides whether two buffer sets can be merged; it is the "necessary and sufficient condition" for merging.
-24- interference_map records the "cannot be merged" relationships between buffer sets.
-24:29- An adjacency list marks which pairs of BufferSets interfere. For example, take the following interference_map:

0:1,2
1:0,3
2:0
3:1

The map says BufferSet0 cannot be merged with BufferSet1 or BufferSet2, and likewise BufferSet3 cannot be merged with BufferSet1.
-30:39- ColorInterferenceGraph assigns a color to each set in the interference_map; sets with the same color may be merged, and using colors makes it easy to count how many buffer sets are needed in the end. The output is the merged colocated_buffer_sets; note that up to this point they still only contain the kWhile, kConditional, and kCall buffer sets. The algorithm first sorts the sets by conflict count in descending order and handles the most conflicted first. This is a heuristic: different orderings produce different final partitions, and "heuristic" means the result is not necessarily globally optimal. Sorting by conflict count guarantees that heavily conflicted sets receive small color values; the larger the color value, the fewer the conflicts, so when sets of equal color are merged, the early colors absorb few sets and the later colors absorb many. An example:

1:{4,5}
2:{4,5}
3:{4}
4:{1,2,3}
5:{1,2}

Question: can 4 and 5 be merged?

Answer: it depends. If 3's color is smaller than 1's and 2's, then 4 and 5 end up with different colors and are not merged; otherwise 4 and 5 receive the same color and can be merged. How are the colors of 4, 1, and 2 decided? By the global sort by conflict count: the more conflicts, the smaller the color. In this example 4 gets the smallest color, and 5 gets the same color as 4. Placing the heavily conflicted nodes first, so they are colored first, is exactly what puts 4 and 5 into one set in cases like this.
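
The greedy coloring described above can be sketched as follows. It implements the sort-by-conflict-count-then-take-smallest-free-color heuristic; treat it as an illustrative reconstruction rather than the exact ColorInterferenceGraph implementation:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int64_t> ColorInterferenceGraph(
    const std::vector<std::vector<int64_t>>& interference_map) {
  const int64_t n = interference_map.size();
  std::vector<int64_t> nodes(n);
  std::iota(nodes.begin(), nodes.end(), 0);
  // Heuristic: color the most-conflicted (highest-degree) nodes first.
  std::sort(nodes.begin(), nodes.end(), [&](int64_t a, int64_t b) {
    return interference_map[a].size() > interference_map[b].size();
  });

  const int64_t kUncolored = -1;
  std::vector<int64_t> assigned_colors(n, kUncolored);
  for (int64_t node : nodes) {
    // Colors already taken by interfering neighbors.
    std::vector<bool> taken(n, false);
    for (int64_t neighbor : interference_map[node]) {
      if (assigned_colors[neighbor] != kUncolored)
        taken[assigned_colors[neighbor]] = true;
    }
    // Take the smallest free color.
    int64_t color = 0;
    while (taken[color]) ++color;
    assigned_colors[node] = color;
  }
  return assigned_colors;
}

int main() {
  // The 4-set example above: 0-1, 0-2 and 1-3 cannot be merged.
  std::vector<std::vector<int64_t>> interference_map = {{1, 2}, {0, 3}, {0}, {1}};
  std::vector<int64_t> colors = ColorInterferenceGraph(interference_map);
  // Sets 0 and 3 share one color, sets 1 and 2 the other: four sets
  // collapse into two.
  return 0;
}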

AssignColocatedBufferSets()

Within each computation, allocate BufferAllocations for the ColocatedBufferSets; at this point the ColocatedBufferSets only cover a few HloOpcodes: Alias, kCall, kWhile, and kConditional.

// For each buffer set in 'colocated_buffer_sets', assigns all buffers in the
// same set to the same buffer allocation in 'assignment'.
void AssignColocatedBufferSets(
AssignColocatedBufferSets()                         // Assign: map LogicalBuffer -> BufferAllocation
  for colocated_buffer_set : colocated_buffer_sets:
    BufferAllocation* allocation = nullptr
    for buffer : colocated_buffer_set:
      if instruction->opcode() == HloOpcode::kParameter && computation == computation->parent()->entry_computation:  
        entry_parameter...
      else if instruction->opcode() == HloOpcode::kConstant:
        is_constant = true  
    for buffer : colocated_buffer_set:
      if  allocation == nullptr:
        // TODO(b/32491382) Avoid current trivial solution of using new
        // allocations for each colocated buffer set. When liveness has
        // module-level scope, we can allow buffers to be shared across
        // computations (in some cases).
        allocation = assignment->NewAllocation(*buffer, buffer_size)
        if is_constant:
          allocation->set_constant(true)
        colocated_allocations->insert(allocation->index())
      else:
        assignment->AddAssignment(...offset = 0, buffer_size) 
      colocated_buffers->insert(buffer)        
    if entry_parameter_number >= 0:
      parameter_has_alias = ...
      allocation->set_entry_computation_parameter()

-2- Iterate over every LogicalBuffer of every colocated_buffer_set in colocated_buffer_sets.
-3- BufferAllocation*: every colocated_buffer_set shares one BufferAllocation.
-5,7- If a ColocatedBufferSet contains a LogicalBuffer used by kParameter or kConstant, the whole BufferAllocation must have the corresponding fields set.
-9- Iterate over each colocated_buffer_set: for the first LogicalBuffer in a set, create a BufferAllocation instance via NewAllocation and record it in the assignment; for the rest, attach the LogicalBuffer to the BufferAllocation already constructed. Note that the offset here is always 0: every LogicalBuffer in the same ColocatedBufferSet starts at address 0 of the memory chunk. That matches the definition of a ColocatedBufferSet: any two buffers in the same set never interfere in time, so giving them the same address causes no problem. On the TODO comment: since the BufferLiveness analysis currently stays within a single HloComputation, the ColocatedBufferSets analysis and the BufferAllocation construction built on it cannot have module-level scope either; there is room for optimization here to further reduce the number of BufferAllocations needed.
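
A small, hypothetical model (not the real BufferAllocation API) of what colocated assignment produces: one allocation, with every member of the colocated set assigned at offset 0:

#include <cstdint>
#include <vector>

struct AssignedBuffer {
  int64_t logical_buffer_id;
  int64_t offset;  // always 0 for members of a colocated set
  int64_t size;
};

struct Allocation {
  int64_t index;
  std::vector<AssignedBuffer> assigned;
};

int main() {
  // E.g. a kWhile init value, body parameter, and body root can all sit at
  // offset 0 of one allocation: within a colocated set they never
  // interfere in time.
  Allocation alloc{/*index=*/0, {}};
  alloc.assigned.push_back({/*logical_buffer_id=*/7, /*offset=*/0, /*size=*/1024});
  alloc.assigned.push_back({/*logical_buffer_id=*/9, /*offset=*/0, /*size=*/1024});
  alloc.assigned.push_back({/*logical_buffer_id=*/12, /*offset=*/0, /*size=*/1024});
  return 0;
}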

GatherComputationsByAllocationType()

An HloComputation is conceptually analogous to a "function". This interface divides computations into global and thread-local according to whether they return a reference or a value; the criterion is quite crude, and there is still plenty of room for optimization:

  1. If an HloComputation instance returns a reference, the instance is global: all BufferAllocations inside it, whether "local variables" or not, are treated as "global variables" (set_thread_local(false)) and are allocated and optimized together, with global-variable lifetimes. On the BufferAllocation this shows up as is_thread_local == false. This scheme stretches lifetimes inside a global HloComputation beyond what is strictly needed, which wastes memory, but it greatly simplifies the optimization logic.
  2. If an HloComputation instance returns a value, the instance is thread-local: all BufferAllocations inside it must be "local variables" (set_thread_local(true)), which simplifies buffer lifetime management: the "local variables" are reclaimed when the HloComputation finishes executing. And precisely because stack variables can be allocated and freed repeatedly, there is no need for the global-style whole-program optimization. Requiring everything to be a "local variable" is rigid, but it, too, greatly simplifies the optimization logic.

These two concepts can be understood by analogy with threads: global_computations -> the main thread, thread_local_computations -> worker threads. But why make this division at all? I think it can be understood from two angles:

  1. The current software layer is HLO: High Level Optimization, which is hardware-independent, yet every piece of execution hardware has the notion of an "execution stream" and of a "main execution stream"; only the concrete realization differs, e.g. main and worker threads on CPU, CUDA Streams on GPU. Abstracting over these hardware execution streams is HLO's job, which is why the code speaks neither of processes, threads, nor CUDA Streams, but only of global and thread_local: a highly abstract execution stream.
  2. For memory optimization, HLO must first know whether two buffers have a hard dependency on each other (BufferLiveness); if not, it must then ask whether the two buffers live on the same execution stream. Roughly speaking, buffers on the same execution stream have a definite order between them and can therefore usually be reused.
// Walk the call graph of the HLO module and place each computation into either
// thread_local_computations or global_computations depending upon whether the
// computation requires thread-local allocations or global allocations. The
// elements in thread_local_computations and global_computations are in post
// order (if computation A has an instruction which calls computation B, then A
// will appear after B in the vector).  --> so thread_local means per-thread variables; thread-unsafe? why speak of threads, where would a thread even be used?
GatherComputationsByAllocationType()
  while !worklist.empty():
    if is_thread_local && in_global_set || 
      ...
    if is_thread_local:
      thread_local_set...
    else:
      global_set...
    for instruction : computation->instructions():
      for subcomputation : instruction->called_computations():
        switch instruction->opcode():
          case HloOpcode::kCall:
          case HloOpcode::kConditional:
          case HloOpcode::kWhile:
            worklist.push_back(std::make_pair(subcomputation, false)) 
            break;
          case HloOpcode::kAllReduce:
          case HloOpcode::kMap:
          case HloOpcode::kReduce:
          case HloOpcode::kReduceWindow:
          case HloOpcode::kScatter:
          case HloOpcode::kSelectAndScatter:
          case HloOpcode::kSort:
          case HloOpcode::kFusion:
            worklist.push_back(std::make_pair(subcomputation,true))
          default:
            return InternalError()

    for computation : module->MakeComputationPostOrder():
      if thread_local_set.contains(computation):
        thread_local_computation->push_back(computation)
      else if global_set.contains(computation):
        global_computations->push_back(computation)

-26- Only instructions that have called_computations are processed.
-28:30- If the instruction is kCall, kConditional or kWhile, the subcomputation is global, echoing AssignColocatedBufferSets().
-32:39- kAllReduce, kMap, kReduce, kReduceWindow, kScatter, kSelectAndScatter, kSort, and kFusion each open a new thread_local stream.

AssignBuffersForComputation()

Find Allocations for the remaining LogicalBuffers. hlo_opcode.h defines 97 HloOpcodes; the work so far has revolved around Alias, kCall, kWhile, and kConditional, and now the remaining HloOpcodes are handled, reusing the already constructed BufferAllocations wherever possible. The reuse policy follows the global/thread-local division established above:

  1. global_computations form the "main stream" and run sequentially; reuse existing BufferAllocations as much as possible.
  2. thread_local_computations must allocate their own.
// Assigns buffers to the instructions in the given computation. "assignment"
// is modified to reflect the new buffer assignments. If is_thread_local is
// true, then all assigned buffers have the is_thread_local flag set to
// true.
AssignBuffersForComputation
  for instruction : computation->instructions():
    for buffer : assignment->points_to_analysis().GetBufferDefinedByInstruction()
      sorted_buffers.push_back()
  for instruction : computation->MakeInstructionPostOrder():
    post_order_position.emplace(instruction, position)
    position++
  absl::c_sort(sorted_buffers,...)
  for buffer : sorted_buffer:
    if colocated_buffers.contains(buffer):
      continue
    if instruction->opcode() == HloOpcode::kConstant:
      BufferAllocation* allocation = assignment->NewAllocation(*buffer, buffer_size)
      allocation->set_constant(true)
      continue
    if is_entry_parameter:
      ...
      continue
    if is_thread_local:
      allocation->set_is_thread_local(true)
      continue
    //global
    if (buffer->IsTopLevel() && !buffer->IsTuple()):
      ...
    ...
    //temp
    if (!assignment->HasAllocation(*buffer) && has_sequential_order &&
        !liveness.MaybeLiveOut(*buffer)):
      (*buffers_to_assign_sequentially)[computation].insert(buffer)
    
    if (!assignment->HasAllocation(*buffer)):
      BufferAllocation* allocation = assignment->NewAllocation(*buffer, buffer_size)

-14- If the buffer is a colocated buffer, continue: AssignColocatedBufferSets() already allocated BufferAllocations for this class (Alias, kCall, kWhile, kConditional), and at offset 0 no less.
-16- Handle kConstant: allocate a new BufferAllocation.
-20- Handle entry kParameter: allocate a new BufferAllocation.
-23- For thread-local buffers, allocate new BufferAllocations.
-24:27- For global buffers, reuse wherever possible.
-28- Temp buffers that still have no assignment are recorded first and handled later.
-33- Whatever still cannot reuse anything gets a new BufferAllocation.

AssignBuffersWithSequentialOrdering()

For the buffers earmarked above for sequential assignment (the global ones), use the HeapSimulator algorithm to pack them into one compact BufferAllocation; this mainly targets the global and temp buffers.

// Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
// the HLO instructions will be executed in the sequential order given by
// assignment->liveness().hlo_ordering().SequentialOrder. If
// 'run_whole_module_heap_simulation' is true, the heap simulation will be run
// assuming all global computations are sequentially ordered.
AssignBuffersWithSequentialOrdering()
  AssignBuffersFromHeapSimulator()
    BufferAllocation* allocation = assignment->NewEmptyAllocation(result.heap_size, color);
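
The heap-simulation idea can be sketched as below, assuming a single sequential order: Alloc at each buffer's definition, Free after its last use, and the heap policy chooses offsets; result.heap_size in the pseudocode above corresponds to heap_size() here. The trivial bump-pointer policy shown is only a stand-in for XLA's real heap algorithms (best-fit style heaps that reuse freed ranges):

#include <cstdint>
#include <map>
#include <string>

class NoReuseHeap {  // simplest possible policy: bump allocation, no reuse
 public:
  int64_t Alloc(const std::string& buffer, int64_t size) {
    int64_t offset = heap_size_;
    offsets_[buffer] = offset;
    heap_size_ += size;
    return offset;
  }
  void Free(const std::string& buffer, int64_t size) {}  // no-op here
  int64_t heap_size() const { return heap_size_; }

 private:
  int64_t heap_size_ = 0;
  std::map<std::string, int64_t> offsets_;  // buffer -> assigned offset
};

int main() {
  // Two temps with disjoint live ranges: a reusing policy could place both
  // at offset 0 and halve the heap; this trivial policy needs 2048 bytes.
  NoReuseHeap heap;
  heap.Alloc("%t0", 1024);
  heap.Free("%t0", 1024);
  heap.Alloc("%t1", 1024);
  return heap.heap_size() == 2048 ? 0 : 1;
}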

CombineTempAllocations()

The last step of the optimization: combine the temp BufferAllocations.

// Combines allocations of temporary buffers of the same color into one big
// BufferAllocation.
CombineTempAllocations():
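
A hedged sketch of the combining step: temp allocations of the same color are laid out back to back (respecting alignment) inside one big BufferAllocation. The types and the alignment handling here are illustrative assumptions:

#include <cstdint>
#include <vector>

struct TempAllocation {
  int64_t color;            // only same-color temps are combined
  int64_t size;
  int64_t combined_offset;  // base offset inside the combined allocation
};

int64_t CombineTempAllocations(std::vector<TempAllocation>& temps,
                               int64_t color, int64_t alignment) {
  int64_t combined_size = 0;
  for (TempAllocation& t : temps) {
    if (t.color != color) continue;
    // Round the base up to the required alignment, then append.
    combined_size = (combined_size + alignment - 1) / alignment * alignment;
    t.combined_offset = combined_size;
    combined_size += t.size;
  }
  return combined_size;  // size of the one big BufferAllocation
}

int main() {
  std::vector<TempAllocation> temps = {{0, 1000, 0}, {0, 24, 0}, {1, 512, 0}};
  // Temps of color 0 land at offsets 0 and 1024; the combined size is 1048.
  return CombineTempAllocations(temps, /*color=*/0, /*alignment=*/64) == 1048 ? 0 : 1;
}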


Related:
Tensorflow XLA Service Explained II
Tensorflow XLA Service Explained I
Tensorflow XLA Client | HloModuleProto Explained
Tensorflow XlaOpKernel | tf2xla Mechanism Explained
Tensorflow JIT Explained
Tensorflow JIT/XLA UML
Tensorflow Forums
