Tensorflow XLA HLO — BufferLiveness

The High Level Optimization of the XLA Service can be divided into device-independent and device-dependent parts. The device-independent optimization is buffer optimization, aiming to use less host memory or device memory (or the memory of other hardware backends); the device-dependent optimization mainly tunes the execution stream according to hardware characteristics. The device-independent HLO optimization consists of three phases:

  1. Analyzing the liveness of a given buffer
  2. Constructing the mapping from logical memory objects to physical memory objects
  3. Optimizing the physical memory, then handing it to the specific backend for further backend-specific optimization

This article analyzes the first phase: the BufferLiveness analysis.

So BufferLiveness means obtaining the lifetime relations among buffers, so that buffer-reuse strategies can be decided. The overall call stack in the source is shown below; this phase in turn consists of three steps: LogicalBufferAnalysis -> TuplePointsToAnalysis -> BufferLiveness

BufferAssigner::Run()
  assigner::CreateAssignment()
    liveness = BufferLiveness::Run(module, std::move(hlo_ordering))  // class BufferLiveness
      liveness = new BufferLiveness()
      liveness->Analyze()
        points_to_analysis_ = TuplePointsToAnalysis::Run()
          logical_buffer_analysis = LogicalBufferAnalysis::Run()
            analysis = new LogicalBufferAnalysis()
            analysis.Analyze()
            return analysis
          analysis = new TuplePointsToAnalysis(logical_buffer_analysis)
          analysis.Analyze()
          return analysis
        maybe_live_out_buffers_ = points_to_analysis_->GetPointsToSet(root).CreateFlattenedSet()
      return liveness
    assignment = new BufferAssignment(module, std::move(liveness), ...)

-1- Entry point of the memory optimization.
-3:16- Obtain the BufferLiveness; comparing with -16-, the BufferLiveness obtained here backs the buffer-optimization decisions there, an elegantly OOP way of building the pipeline.
-8- Obtain the LogicalBufferAnalysis.
-9- Based on the Shape of each HloInstruction, construct a corresponding set of LogicalBuffers. A Shape is itself a tree structure describing the output shape of one HloInstruction; each LogicalBuffer corresponds to one node (subshape) of a Shape and stands for one logical buffer, indexed by a pair<HloInstruction*, ShapeIndex>.
-10- The raw material of the liveness analysis, the LogicalBuffers, is now fully constructed and stored; return analysis and move on to the next step, TuplePointsToAnalysis.
-11- Construct the TuplePointsToAnalysis instance from the LogicalBufferAnalysis instance that holds all the LogicalBuffers. "Tuple" is the way the tree structure of buffers is described: E = {%1, %2, {%3, %4}} denotes a tree of depth 3: one node at depth 1, the root; three nodes at depth 2, namely %1, %2 and an unnamed node we may call "X"; and two nodes at depth 3, %3 and %4, which are children of "X". "PointsTo" means exactly that: in this example, the root "points to" %1, %2 and "X", while "X" points to %3 and %4. So TuplePointsToAnalysis analyzes the LogicalBuffer dependencies across the whole computation graph and stores them in PointsToSets; a toy sketch follows this annotation list, and a detailed analysis comes later.
-12- Run the analysis logic.
-13- The dependency analysis of the LogicalBuffers in the computation graph is done and stored; return analysis and move on to the next step, BufferLiveness.
-14- Use the TuplePointsToAnalysis instance to fetch the root's aliased buffers, i.e. the buffers of the HloModule that may potentially be passed out, and store them in maybe_live_out_buffers_.
-15- Return the liveness instance. The TuplePointsToAnalysis instance stores the dependency relations of the LogicalBuffers, but BufferLiveness does not store a "liveness" per LogicalBuffer; instead it wraps TuplePointsToAnalysis with a set of predicates for querying a specific LogicalBuffer.
-16- Pass the BufferLiveness instance into the BufferAssignment instance, which builds the mapping between LogicalBuffers and BufferAllocations.
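
To make -9- and -11- concrete, below is a minimal toy sketch in plain C++ (illustrative names only; LogicalBufferToy and ShapeIndexToy are made-up stand-ins for XLA's LogicalBuffer and ShapeIndex) of one logical buffer per shape-tree node, keyed by an (instruction, index-path) pair, for the tuple E = {%1, %2, {%3, %4}}:

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using ShapeIndexToy = std::vector<int64_t>;  // path into the shape tree; {} = root

struct LogicalBufferToy {
  int64_t id;                // unique within the analysis
  std::string instruction;   // stands in for HloInstruction*
  ShapeIndexToy index;       // which subshape this buffer represents
};

int main() {
  // One buffer per node of the shape tree of E = {%1, %2, {%3, %4}}.
  std::map<std::pair<std::string, ShapeIndexToy>, LogicalBufferToy> buffers;
  int64_t next_id = 0;
  auto add = [&](const std::string& instr, const ShapeIndexToy& idx) {
    buffers[{instr, idx}] = LogicalBufferToy{next_id++, instr, idx};
  };
  add("E", {});      // the top-level tuple buffer (depth 1, the root)
  add("E", {0});     // %1
  add("E", {1});     // %2
  add("E", {2});     // the unnamed inner tuple "X"
  add("E", {2, 0});  // %3
  add("E", {2, 1});  // %4
  std::cout << "logical buffers: " << buffers.size() << "\n";  // prints 6
}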

The above is the main flow of the whole BufferLiveness analysis; next come some key details.

LogicalBufferAnalysis

class LogicalBufferAnalysis : public DfsHloVisitorWithDefault

LogicalBufferAnalysis "is a" DfsHloVisitorWithDefault; the latter is the interface the HLO layer provides for traversing HloModule/HloComputation/HloInstruction in visitor style (yes, the Visitor pattern from design patterns). It is a basic tool for XLA's HLO optimizations and, for anyone extending the code, a handy tool for understanding and debugging it as well. For a "DfsHloVisitorWithDefault", what we mainly care about are its various "Handle" implementations, plus the parts specific to LogicalBufferAnalysis as a subclass.
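
Before diving into the real Handle list, here is a toy illustration of the visitor dispatch itself (made-up types, not the real DfsHloVisitorWithDefault API): every opcode falls through to DefaultAction unless a specific Handle is overridden.

#include <iostream>

enum class Opcode { kAdd, kTuple };

struct Instruction {
  Opcode opcode;
};

class HloVisitorToy {
 public:
  virtual ~HloVisitorToy() = default;
  // Everything not specially handled falls through to this.
  virtual void DefaultAction(const Instruction& instr) = 0;
  virtual void HandleTuple(const Instruction& instr) { DefaultAction(instr); }
  // Stand-in for HloInstruction::Accept(): route by opcode.
  void Visit(const Instruction& instr) {
    switch (instr.opcode) {
      case Opcode::kTuple: HandleTuple(instr); break;
      default:             DefaultAction(instr); break;
    }
  }
};

class LogicalBufferAnalysisToy : public HloVisitorToy {
 public:
  void DefaultAction(const Instruction&) override {
    std::cout << "one LogicalBuffer per subshape\n";
  }
  void HandleTuple(const Instruction&) override {
    std::cout << "only the top-level buffer\n";
  }
};

int main() {
  LogicalBufferAnalysisToy analysis;
  analysis.Visit({Opcode::kAdd});    // default path
  analysis.Visit({Opcode::kTuple});  // specialized path
}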

//tensorflow/compiler/xla/service/logical_buffer_analysis.h
Status DefaultAction(HloInstruction* hlo_instruction) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleDomain(HloInstruction* domain) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;

-2- The default Handle: it walks every subshape of the HloInstruction and calls NewLogicalBuffer() to construct a LogicalBuffer for it. Each LogicalBuffer carries an id unique within the LogicalBufferAnalysis, and since a compilation graph has only one LogicalBufferAnalysis instance at this point, the id is effectively unique within the HloModule. The constructed LogicalBuffers are stored in both 'logical_buffers_' and 'output_buffers_'; the main difference between the two is how they are indexed: logical_buffers_ owns the raw LogicalBuffer instances as unique_ptrs, while the latter is hashed with std::pair<HloInstruction*, ShapeIndex> as key, enabling fast LogicalBuffer lookup. If you look into the implementation of ShapeIndex, you will see that its index is not a single "id" but an "absl::InlinedVector<int64, 2> indices_", where each int64 in indices_ denotes one node of the HloInstruction's Shape tree (a multiway tree). ShapeIndex is responsible only for the indexing part; what the "leaf" is, is up to the user. Here the leaf of LogicalBufferAnalysis is the LogicalBuffer, and a hash map establishes the relation between ShapeIndex and LogicalBuffer; a toy sketch of these two containers follows the Handle annotations below.


-3:10- Clearly, these eight "non-default" Handles exist because the default DefaultAction does not fit the HloOpcode each of them is responsible for. The comments inside them are quite clear, well, provided you already understand TuplePointsToAnalysis. LogicalBufferAnalysis and TuplePointsToAnalysis live in separate files mainly because of the Visitor-style processing; viewed from a single HloInstruction or a single buffer, the two designs go hand in hand and usually need to be read together. For understanding the code, the following Handles deserve special attention:
-3- HandleTuple: NewLogicalBuffer(tuple, /*index=*/{}); allocates a LogicalBuffer indexed by "{}". At first sight "{}" does not match the meaning of "ShapeIndex", since it seems to carry no "shape" information at all. In fact, all you need to know here is that "{}" is the index because kTuple only needs a "top-level buffer"; what exactly a "top-level buffer" is will be analyzed in detail under TuplePointsToAnalysis.
-5:6- HandleBitcast and HandleDomain construct no LogicalBuffer at all.
-7- HandleCopy simply constructs one new LogicalBuffer, again a "top-level buffer".
-8- HandleRecvDone: two LogicalBuffers; the top-level buffer at {} holds the received data in the form of a buffer alias, and the buffer at {1} holds the communication token.
-9- HandleSend: three LogicalBuffers; the top-level buffer at {} holds the data to be sent in the form of a buffer alias, the buffer at {1} holds the communication context, and the buffer at {2} holds the communication token.
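
As promised in -2-, here is a toy sketch of the two containers (illustrative names and containers that mirror the intent of logical_buffers_ and output_buffers_, not their exact types): owning storage plus a hash index keyed by (instruction, index).

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

struct InstrToy {};  // stands in for HloInstruction

struct LogicalBufferToy {
  int64_t id;
  const InstrToy* instruction;
  std::vector<int64_t> index;  // stands in for ShapeIndex::indices_
};

using KeyToy = std::pair<const InstrToy*, std::vector<int64_t>>;

struct KeyHashToy {
  size_t operator()(const KeyToy& k) const {
    size_t h = std::hash<const void*>()(k.first);
    for (int64_t i : k.second) h = h * 31u + std::hash<int64_t>()(i);
    return h;
  }
};

class LogicalBufferAnalysisToy {
 public:
  LogicalBufferToy* NewLogicalBuffer(const InstrToy* instr,
                                     std::vector<int64_t> index) {
    auto buf = std::make_unique<LogicalBufferToy>(
        LogicalBufferToy{next_id_++, instr, index});
    LogicalBufferToy* raw = buf.get();
    logical_buffers_.push_back(std::move(buf));  // owning storage, id-ordered
    output_buffers_[{instr, index}] = raw;       // hash index for fast lookup
    return raw;
  }

 private:
  int64_t next_id_ = 0;
  std::vector<std::unique_ptr<LogicalBufferToy>> logical_buffers_;
  std::unordered_map<KeyToy, LogicalBufferToy*, KeyHashToy> output_buffers_;
};

int main() {
  LogicalBufferAnalysisToy analysis;
  InstrToy tuple;
  analysis.NewLogicalBuffer(&tuple, {});  // kTuple: only the top-level buffer
}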

With the Visitor part covered, let's now look at the specific part.

Status LogicalBufferAnalysis::Analyze()
  // We filter out fusion computations, and get to them through fusion
  // instructions. This is because it's possible to have orphaned (unreachable)
  // fusion computations, and we don't want to try to assign buffers to those.
  std::vector<HloInstruction*> fusion_instructions;
  for (auto* computation : module_->MakeNonfusionComputations()):
    computation->Accept(this);
    for (auto* instruction : computation->instructions()):
      if (instruction->opcode() != HloOpcode::kFusion):
        continue;
      GatherFusionInstructions(instruction, &fusion_instructions);
  for (auto* instruction : fusion_instructions):
    instruction->fused_expression_root()->Accept(this);

The function above is the entry point of the traversal, and it filters out kFusion computations. The comment says it well: the computation graph is "pruned" so that no LogicalBuffer is allocated for unreachable fusion computations, which would waste memory; this kind of handling shows up in many places in XLA. There is a classic question here: the TF runtime already removed unreachable nodes during its earlier graph processing, so why does XLA do it again? The key is the Op -> HloInstruction mapping: XLA composes hundreds of Op types out of just a handful of HloOpcodes. Although TF removed unreachable Ops up front, once XLA converts the graph of Ops into a graph of HloInstructions, new unreachable nodes can appear; these nodes are no longer Ops but finer-grained HloInstructions. Still, granularity aside, many of the graph-optimization algorithms run by the classic TF engine and by XLA are similar, yet the two never abstracted these algorithms at a higher level, hence the duplicated functionality. Some readers feel that, from an OOP point of view, there is room for improvement here; my take is the exact opposite: rather than introducing further abstraction and dependencies, XLA would do best to decouple itself from the TF engine completely and plug into the TF core as an external plugin. Feel free to leave me a comment if you'd like to discuss.

TuplePointsToAnalysis

The main difficulty in understanding this analysis is understanding what "TuplePointsTo" means; once that is clear, the internal design is self-evident. For an HloInstruction, "PointsTo" is the content being expressed: the buffers that this HloInstruction depends on. "Tuple" is the form of expression: the topology among buffers is written as "{..., {...}}".

In the source, the following comment explains it well:

  // For the subshape at the given index (where index is defined as in
  // ShapeUtil::GetSubshape) this method returns the set of HLO instructions
  // which may produce the tuple subshape at that index. For example, given:
  //
  // %tuple1 = tuple(...)
  // %tuple2 = tuple(...)
  // %select = select(%tuple1, %tuple2)
  // %nested_tuple = tuple(%select, %tuple1)
  //
  // These are the values for tuple_sources() for the PointsToSet of
  // %nested_tuple:
  //
  // tuple_sources({}) = {%nested_tuple}
  // tuple_sources({0}) = {%tuple1, %tuple2}
  // tuple_sources({1}) = {%tuple1}
  //
  // tuple_sources() at the index of an array shape (not a tuple) returns the
  // empty set. The instructions in the set returned by tuple_sources
  // necessarily are either Tuple instructions, constants, or parameters.

Rendered as diagrams, the left figure illustrates the example above: the blue arrows are the meaning of "PointsTo", i.e. the buffer indices, and all the arrows together with the buffers they reach form the kTuple's "PointsToSet". Personally I don't think kTupleSelect is an ideal example here, since it is not handled by the DefaultAction used in most cases but, like kTuple, aliases its input buffers; so I drew the simpler right figure, which hopefully makes things clearer.
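To complement tuple_sources(), here is my reading of the buffers side of %nested_tuple's PointsToSet for the same example (a hedged reconstruction, not a verbatim quote from the source):

  // points_to({})  = {the LogicalBuffer defined by %nested_tuple at {}}
  // points_to({0}) = {the LogicalBuffer defined by %select at {}}
  // points_to({1}) = {the LogicalBuffer defined by %tuple1 at {}}
  //
  // Deeper indices point at the element buffers of the original operands;
  // under {0} they are the union of %tuple1's and %tuple2's element buffers,
  // because kTupleSelect aliases the buffers of both candidate operands while
  // defining a fresh top-level buffer of its own.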

Having explained what "TuplePointsTo" means, let's analyze the implementation details from the source.

// A class describing the source(s) of the Buffer(s) contained in the output of
// a particular HLO instruction. The structure of PointsToSet mirrors the
// structure of the instruction's shape, which may be an arbitrary tree (eg, a
// nested tuple). Each node in this tree corresponds to a single buffer in the
// instruction's output and contains the set of Buffers which might define
// the corresponding buffer.
class PointsToSet {
 private:
  struct Elem {
    BufferList buffers;
    SourceSet tuple_sources;
  };
  ShapeTree<Elem> tree_;
};
// DFS visitor that performs tuple points-to analysis. This analysis determines
// the potential sources of each buffer in each instruction's output.
class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
  // Information kept per instruction
  struct PerInstruction {
    std::unique_ptr<PointsToSet> points_to_set;
    // Empirically, ~92% of instructions have 1
    // instruction_defined_buffer, and 99% have 0 or 1
    BufferDefinitionVector instruction_defined_buffers;
  };

  // The logical buffers for this module.
  const std::unique_ptr<LogicalBufferAnalysis> logical_buffer_analysis_;

  // A map from instruction->unique_id() to
  absl::flat_hash_map<int, std::unique_ptr<PerInstruction>> per_instruction_;

  // A map from LogicalBuffer->id() to alias information about that logical
  // buffer
  std::vector<BufferAliasVector> logical_buffer_aliases_;
};

Likewise, TuplePointsToAnalysis is a DfsHloVisitorWithDefault, and as with any visitor we mainly care about the Handle implementations. The attentive reader may have noticed that the Handle names here are exactly the same as those implemented by LogicalBufferAnalysis. Indeed, each Handle in TuplePointsToAnalysis relies on the corresponding Handle in LogicalBufferAnalysis for its job, so when reading this part of the code, if some Handle of TuplePointsToAnalysis is unclear, take a look at the Handle of the same name in LogicalBufferAnalysis.

Status DefaultAction(HloInstruction* hlo_instruction) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleDomain(HloInstruction* domain) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleAddDependency(HloInstruction* add_dependency) override;

-1- DefaultAction: LogicalBufferAnalysis's DefaultAction() constructs LogicalBuffer instances from the Shape; the DefaultAction() here takes that result and stores each LogicalBuffer in the buffers of the PointsToSet. In addition, for an HloInstruction whose shape is a tuple, the HloInstruction itself is recorded as the "source" of those buffers. As for which XlaOps set their shape to a tuple, they are: Infeed, InfeedWithToken, RecvFromHost, RecvWithToken, SendToHost, SendWithToken, all XlaOps used to move bulk data. Constructing the PointsToSet instances is the main job of TuplePointsToAnalysis in its role as a DfsHloVisitorWithDefault.
-2- On top of LogicalBufferAnalysis's HandleTuple(), construct the PointsToSet instance (see the sketch below). Note two things: 1. the index used inside the Handle to fetch a LogicalBuffer matches the index under which LogicalBufferAnalysis stored it; 2. for a Tuple, its PointsToSet is the collection of the PointsToSets of all its operands, including both the buffers and those buffers' sources. This point is used again later in BufferLiveness.
-3:10- On top of the corresponding LogicalBufferAnalysis Handles, construct the corresponding PointsToSet instances.
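
The -2- case deserves a concrete sketch. Below is a toy version (HandleTupleToy, PointsToToy and IndexToy are made-up containers, not XLA's ShapeTree) of how a tuple's points-to set can be assembled by copying each operand's set under the prefix {operand_number}:

#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using IndexToy = std::vector<int64_t>;
// index -> names of pointed-to buffers (stand-ins for LogicalBuffer*).
using PointsToToy = std::map<IndexToy, std::set<std::string>>;

PointsToToy HandleTupleToy(const std::vector<PointsToToy>& operands,
                           const std::string& own_top_level_buffer) {
  PointsToToy result;
  result[{}] = {own_top_level_buffer};  // the tuple's own top-level buffer
  // Copy each operand's set under the prefix {operand_number}.
  for (int64_t i = 0; i < static_cast<int64_t>(operands.size()); ++i) {
    for (const auto& [index, buffers] : operands[i]) {
      IndexToy prefixed{i};
      prefixed.insert(prefixed.end(), index.begin(), index.end());
      result[prefixed].insert(buffers.begin(), buffers.end());
    }
  }
  return result;
}

int main() {
  PointsToToy a{{{}, {"a.top"}}};                  // array-shaped operand
  PointsToToy t{{{}, {"t.top"}}, {{0}, {"t.0"}}};  // tuple-shaped operand
  PointsToToy tuple = HandleTupleToy({a, t}, "tuple.top");
  std::cout << tuple.size() << " indexed entries\n";  // {}, {0}, {1}, {1,0}: 4
}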

With the common part done, let's look at the specific part of TuplePointsToAnalysis. As mentioned above, the Analyze() method is the entry point.

TuplePointsToAnalysis::Analyze()
  std::vector<HloInstruction*> fusion_instructions
  for computation : module_->MakeNonfusionComputations():
    computation->Accept(this)
    PopulateDefinedBuffersAndAliases(computation->instructions())
      for instruction : instructions:
        PerInstruction* pi = PerInst(instruction)
        GatherBuffersDefinedByInstruction(instruction, &pi->instruction_defined_buffers))
        points_to_set = GetPointsToSet(instruction)
        points_to_set.ForEachElement(
          [this, &instruction](const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) {
            for (const LogicalBuffer* buffer : pointed_to_buffers):
              logical_buffer_aliases_[buffer->id()].emplace_back(instruction, index)
          })
    for instruction : computation->instructions():
      if instruction->opcode() == HloOpcode::kFusion:
        GatherFusionInstructions(instruction, &fusion_instructions)
  // Run points-to analysis on fusion instructions in 'computation'.
  for instruction : fusion_instructions:
    instruction->fused_expression_root()->Accept(this))
    PopulateDefinedBuffersAndAliases(instruction->fused_instructions())

-3- Just like in LogicalBufferAnalysis, the same "pruning" is performed here.
-5- Besides running the Handles via Accept(), TuplePointsToAnalysis post-processes the Handles' results through PopulateDefinedBuffersAndAliases(), producing two storage structures: 'TuplePointsToAnalysis::PerInstruction::instruction_defined_buffers' and 'TuplePointsToAnalysis::logical_buffer_aliases_'. instruction_defined_buffers is built by GatherBuffersDefinedByInstruction(), which again walks all the subshapes. After the previous passes, an HloInstruction's LogicalBuffers can already be obtained easily, so why traverse once more? If you have this question, look at HandleTuple() and the other "customized" Handles of TuplePointsToAnalysis: for kTuple and a few other HloOpcodes, the sources of the PointsTo are not the instruction itself but the sources of its operands' PointsTo. Clearly such an HloInstruction has no "defined buffers" of its own, while the precondition of the later allocation assignment is that a "buffer is defined by an instruction"; this step therefore picks out exactly the "buffers defined by the instruction" that will get assigned. Along the way, the implementation also performs a sanity check: for the kTuple-like HloInstructions, none of their buffers' sources may be the instruction itself. As for logical_buffer_aliases_, it can be seen as a reverse index, which is what "alias" refers to here: it lets a user quickly map a LogicalBuffer->id() back to the corresponding HloInstruction and ShapeIndex.
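
A toy sketch of that reverse index (made-up AliasToy type, illustrative only): from a buffer id back to every (instruction, index) position whose points-to set contains it.

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using AliasToy = std::pair<std::string, std::vector<int64_t>>;  // (instr, index)

int main() {
  // logical_buffer_aliases[id] lists where buffer `id` shows up: e.g. a buffer
  // defined by %param at {} also appears in %tuple's output at index {0}.
  std::vector<std::vector<AliasToy>> logical_buffer_aliases(1);
  logical_buffer_aliases[0].push_back({"%param", {}});
  logical_buffer_aliases[0].push_back({"%tuple", {0}});
  for (const auto& [instr, index] : logical_buffer_aliases[0]) {
    std::cout << instr << " at an index of " << index.size() << " steps\n";
  }
}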

BufferLiveness

With LogicalBufferAnalysis building the raw material (the LogicalBuffers) and TuplePointsToAnalysis building the relations among them (the PointsToSets), most of the hard work is done; BufferLiveness merely wraps them to expose interfaces for judging the liveness of a LogicalBuffer, mainly the following:

  // Returns true if the live range of the buffer containing the output of 'a'
  // may overlap with the live range of the buffer of 'b'. If instruction 'a'
  // interferes with instruction 'b' then they cannot share the same buffer.
  bool MayInterfere(const LogicalBuffer& a, const LogicalBuffer& b) const;

  // Returns true if the buffer for the given instruction may be live out of the
  // module. That is, the instruction's buffer may be included in the output of
  // the entry computation.
  bool MaybeLiveOut(const LogicalBuffer& buffer) const;

  // Returns the complete set of buffers that may be live out of the module.
  const PointsToSet::BufferSet& maybe_live_out_buffers() const;
  // Returns the underlying points-to analysis used for this liveness analysis.
  const TuplePointsToAnalysis& points_to_analysis() const;

Here, again, a few small details deserve attention.

Status BufferLiveness::Analyze() {
  points_to_analysis_ = TuplePointsToAnalysis::Run(module_)
  for computation : module_->computations():
    if computation->IsFusionComputation():
      continue;
    // Gather all instructions whose buffers might alias other instructions into
    // the set aliased_buffers_.  This includes those contained as a tuple
    // element in other instruction's output.
    for instruction : computation->instructions():
      for LogicalBuffer* aliased_buffer : points_to_analysis_->GetPointsToSet(instruction).CreateFlattenedSet():
        if (aliased_buffer->instruction() != instruction):
          aliased_buffers_.insert(aliased_buffer);
    if computation == module_->entry_computation():
      HloInstruction* root = computation->root_instruction();
      maybe_live_out_buffers_ = points_to_analysis_->GetPointsToSet(root).CreateFlattenedSet();

-11- If the source of a LogicalBuffer found in an HloInstruction's PointsToSet is not that HloInstruction itself, the HloInstruction is Tuple-like: it has no substantial LogicalBuffer of its own and is not the source of any LogicalBuffer; such buffers are collected into aliased_buffers_.
-15- A live-out buffer is a buffer whose lifetime extends beyond the entry_computation (or, if you prefer, the HloModule): it is a buffer carrying the output of the whole entry_computation. Since root is the overall entry and exit of the program, it is easy to accept that the LogicalBuffers in its PointsToSet are the program's overall output buffers; but why "maybe"? According to the comments, later execution-flow optimization may adjust the hlo_ordering, so root may no longer be the last HloInstruction, hence "maybe". From my understanding of XLA, though, I cannot think of an optimization that would move root earlier; perhaps the programmer of this file just left a way out, because, well, you never know; more likely I am overlooking some detail (TODO).

The following two are the core methods of BufferLiveness. They are called frequently by the later optimizations; their logic is simple, so I won't expand on them beyond the usage sketch after the declarations.

  // Returns true if the live range of the buffer containing the output of 'a'
  // may overlap with the live range of the buffer of 'b'. If instruction 'a'
  // interferes with instruction 'b' then they cannot share the same buffer.
  bool MayInterfere(const LogicalBuffer& a, const LogicalBuffer& b) const;
  // Returns true if the buffer for the given instruction may be live out of the
  // module. That is, the instruction's buffer may be included in the output of
  // the entry computation.
  bool MaybeLiveOut(const LogicalBuffer& buffer) const;
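
As promised, a hedged usage sketch: BufferLivenessToy and CanShareAllocation are made-up names, and the policy here is a simplification rather than BufferAssigner's real one, but it shows the shape of the decision: two buffers may be packed into one allocation only if their live ranges cannot overlap.

#include <iostream>

struct BufferLivenessToy {                 // stand-in for BufferLiveness
  bool MayInterfere(int a, int b) const {  // stub: pretend buffers 1 and 2 overlap
    return (a == 1 && b == 2) || (a == 2 && b == 1);
  }
};

bool CanShareAllocation(const BufferLivenessToy& liveness, int a, int b) {
  return !liveness.MayInterfere(a, b);
}

int main() {
  BufferLivenessToy liveness;
  std::cout << CanShareAllocation(liveness, 0, 1) << "\n";  // 1: reuse is fine
  std::cout << CanShareAllocation(liveness, 1, 2) << "\n";  // 0: must not share
}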




