Tensorflow XLA Client | HloModuleProto 详解

compiler/aot/ 以AOT的方式将tf2xla/接入TF引擎
compiler/jit/以JIT的方式将tf2xla/接入TF引擎, 核心是9个优化器和3个tfop,其中XlaCompileOp调用tf2xla的“编译”入口完成功能封装,XlaRunOp调用xla/client完成“运行”功能。
compiler/tf2xla/对上提供xla_compiler.cc:XlaCompiler::CompileFunction()供jit:compile_fn()使用将cluster转化为XlaComputation。核心是利用xla/client提供的接口,实现对XlaOpKernel的“Symbolic Execution”功能。每个XlaOpKernel子类均做的以下工作: 从XlaOpKernelContext中取出XlaExpression或XlaOp, 调用xla/client/xla_buidler.h提供的方法完成计算, 将计算结果的XlaOp存入XlaKernelContext.**
compiler/xla/client/ 对上提供xla_builder.cc:Builder等供CompileFunction()使用,将Graph由Op表达转化为HloModuleProto:HloComputationProto:HloInstructionProto表达并保存在XlaComputation中。
对上提供local_client.cc:LocalClient::Compile(),作为编译入口供jit:BuildExecutable()使用,将已经得到的XlaComputation交给service并进一步编译为二进制。
对上提供local_client.cc:LocalExecutable::Run(),作为运行入口供jit/kernels/xla_ops.cc:XlaRunOp使用,通过Key找到相应的二进制交给service层处理
compiler/xla/service/ 对上提供local_service.cc:LocalService::BuildExecutable()供LocalClient::Compile()使用实现真正的编译,承接XlaComputation封装的HloProto, 将其转化为HloModule:HloComputation:HloInstruction表达, 对其进行优化之后, 使用LLVM后端将其编译为相应Executable后端的二进制代码
对上提供executable.cc:Executable::ExecuteOnStream()供LocalExecutable::Run()使用实现真正的执行二进制。

compiler/xla/client 向上为tf2xla/下的XlaOpKernel的实现提供支撑, 将上层请求转换为HloModule交给下层xla/service优化并编译.

接口上, client做上表中的三件事 , 实际上, 只有XlaOpKernel->HloProto在Client完成, 而对于另外两个编译和执行,都是在service中完成的,或者说,Client本质上是Service的Facade + Proxy,对繁琐的构造HloModuleProto的过程进行了封装,降低了XLA Service与上层模块的耦合

 client.h:Client    Client基类, 用于多态实现
 client_library.h:ClientLibarary  使用单例构造client_library对象,  用于检索/构造所需的Client实例
 lib/   同builder一起实现”Symbolic Execution”
 local_client.h:LocalClient, LocalExecutable JIT 使用的LocalClient定义, 是Service相关方法的Proxy
 xla_builder.h:XlaBuilder    提供接口用tf2xla使用实现”Symbolic Execution”, 是其中XlaBuilder::Build()是构造client构造HloModuleProto的核心方法。对于需要多个步骤完成初始化的类, 我们会使用Builder模式, 这就是个例子
 xla_computation.h:XlaComputationXlaComputation对象是对HloModuleProto的封装, 用于进一步二进制编译

职责1: 构造HloModuleProto

编译二进制之前首先要完成Graph表达方式的映射: Client之前的tf2xla的Graph由使用Op表达, Client之后的Service的Graph使用HloInstruction表达, Client负责完成这种转化, 具体地, 就是将Op转化为HloProto格式, 再交给Service解析为Hlo格式, 其中的HloProto就是封装在XlaComputation中. 所以, 这个过程可以看做是”编译”的准备工作. 在这个过程中, Graph, Cluster, XlaComputation, HloModuleProto, HloModule是一一对应的,都是”Program”的概念, 这一点从源码的ProgramShape等变量名中都可以体现

按照OOP的方式理解OOP的代码, 在说明Client是如果构造HloModuleProto之前, 来介绍几个概念, HloModuleProto-HloComputationProto- HloInstructionProto, 这3个概念是XLA Service提供给上层调用者(Here, XLA Client)的用于构造HloModule-HloComputation-HloInstruction的protobuf形式的”原材料”, Module-Computation-Instruction, 在概念上可以对应程序-函数-语句

  1. HloProto中的函数调用都是以”半inline“方式实现, inline的部分是指类似”a = add(x, y);”这样的语句, 会将add函数体整个copy到该语句的上下文, 在执行的时候调用之, 在HloProto中, 被调用的XlaComputation也会被copy到调用语句的上下文, “上下文”就是调用语句HloInstructionProto所处的XlaComputation(对应一个HloModuleProto)的XlaBuilder实例. “半”的部分是指被copy的XlaComputation并不是直接在调用语句处展开, 而是在调用语句的HloModuleProto中存储被调用的XlaComputation的Id, 执行的时候直接跳转执行.
  2. 类似多文件编程, 构造的过程可以有多个HloModuleProto, 但由于上一条所述, 最终所有的被调用函数均会被copy到根HloModuleProto的上下文, 最终只有也只需要根HloModuleProto就足以以描述整个程序的逻辑
  3. 一个C/C++程序有main作为入口, 对应到HloModuleProto就是entry_computation, 一个函数的第一条语句, 对应到HloComputationProto就是root_id

下面我们简要分析下代码, 既然Hlo是描述一个程序, 那么就需要解决一个基本的问题: 描述出函数->语句->函数->语句…这样的递归结构, OOP的Composite模式, 或者内核的kset-kobj结构, 都是解决递归的好方式, Hlo模块的实现也比较简单, 可以认为是kset-kobj的结构的C++版本:在HloComputationProto里记录HloInstructionProto,在HloInstructionProto里记录被调用的另一个HloCProto实例Id

tensorflow::XlaCompilationCache::Compile(kernel/out_compilation_result, executable/out_executable)
  return tensorflow::XlaCompilationCache::CompileImpl(out_compilation_result, out_executable)
    if (!entry->compiled):
      e->std::function<Status(XlaCompiler* compiler,...(entry->compilation_result)
        tensorflow::XlaCompiler::CompileFunction(out_compilation_result/result, fn_name_attrs)
          tensorflow::XlaCompiler::CompileGraph()
            ExecuteGraph(context/xla_context)
              tensorflow::GraphCompiler::Compile()
                for node in topo_sorted_nodes:
                  tensorflow::XlaCompilationDevice::Compute(op_context)
                    tensorflow::XlaOpKernel::Compute()
                      XlaOpKernelContext xla_context(context)
                      Compile(&xla_context);
                          //xla/client/xla_builder.cc
                          //CrossReplicaSum
                          b = CreateSubBuilder("sum");
                            return sub_builder = absl::make_unique<XlaBuilder>(computation_name);
                          Add(b->Parameter(),b->Parameter();
                          auto computation = b->Build();
                          AddCalledComputation(computation, &instr)
                            for e : computation.computations():
                              HloComputationProto new_computation(e);
                              computation_id = GetNextId();
                              remapped_ids[new_computation.id()] = computation_id
                              for instruction : *new_computation.mutable_instructions():
                                int64 instruction_id = GetNextId();
                                remapped_ids[instruction.id()] = instruction_id;
                                new_computation.set_root_id(remapped_ids.at(new_computation.root_id()))
                                imported_computations.push_back(new_computation);
                            instr->add_called_computation_ids(remapped_ids.at(computation.proto().entry_computation_id()))
                            for (auto& imported_computation : imported_computations):
                              for (auto& instruction : *imported_computation.mutable_instructions()):
                                for operand_id : *instruction.mutable_operand_ids():
                                  operand_id = remapped_ids.at(operand_id);
                                for control_predecessor_id : *instruction.mutable_control_predecessor_ids():
                                  control_predecessor_id = remapped_ids.at(control_predecessor_id);
                                for alled_computation_id : *instruction.mutable_called_computation_ids():
                                  called_computation_id = remapped_ids.at(called_computation_id);
                              computation_id = imported_computation.id();
                              embedded_.insert({computation_id, std::move(imported_computation)});
                          AddInstruction(HloInstructionProto instr)
                            instr.set_id(handle);
                            instr.set_opcode(HloOpcodeString(opcode));
                            instr.set_name()
                            handle_to_index_[handle] = instructions_.size();
                            instructions_.push_back(std::move(instr));
                            XlaOp op(handle, this);
                            return op
            BuildComputation()
              computation_status = builder->Build();
                Build(instructions_.back().id()...)
                    for instruction in instructions_:
                      entry.add_instructions()->Swap(&instruction);
                  XlaComputation computation(entry.id());
                  HloModuleProto* module = computation.mutable_proto();
                  module->set_name(entry.name());
                  module->set_id(entry.id());
                  module->set_entry_computation_name(entry.name());
                  module->set_entry_computation_id(entry.id());
                  *module->mutable_host_program_shape() = entry.program_shape();
                  for e in embedded_:
                    module->add_computations()->Swap(e.second)
                  module->add_computations()->Swap(&entry);
              computation = computation_status.ConsumeValueOrDie();
      e->tensorflow::XlaCompilationCache::BuildExecutable()

-1- jit/xla_compilation_cache.cc
-2- jit的下边界, XlaCompiler::CompileFunction(tf2xla/xla_compiler.cc), 进入tf2xla
-3- 对于cache中没有的entry, 真正的去编译
-4- compile_fn()
-5- tf2xla/xla_compiler.cc
-9- 每一个真正执行编译的node
-14- tf2xla/kernel的边界, 用于构造XlaComputation, 个人认为是Client的核心文件, 里面实现的XlaOp是client对上(Here, tfxla)提供的接口
-15- CrossReplicaSum恰好是10个需要调用其他XlaComputation的XlaOp(Call, Conditional, CrossReplicaSum, Map, Reducem ReduceWindow, Scatter, SelectAndScatterWithGeneralPadding, Sort, While)之一, 简单又全面, 这里作为例子
-16- XlaBuilder里构造XlaBuilder, 类似编译a.c的时候遇到了对b.c的依赖, 将b.c 构造好在作为依赖项copy到a.c
-18- XlaComputation的函数体
-19- XlaBuilder::Build(), 参考
-51-, Build()的作用就是将该XlaBuilder下存储在instructions_的HloInstructionProto和embedded_的HloComputation构造HloModuleProto, 并封装到XlaComputation中并返回
-20- CrossReplicaSum通过AddCalledComputation添加所调用的XlaComputation, 这个函数整体上做2件事: 1. 将被调用的XlaComputation deep copy到当前的XlaBuilder的instruction_和embedded_, 2. 在调用点的instruction处记录该XlaComputation的entry_computation_id
-41- 将准备好的Instruction添加到当前的XlaBuilder, 注意, CrossReplicaSum通过CreateSubBuilder创造的XlaBuilder实例是局部变量, 调用方需要的信息已经完全copy到了调用方的XlaBuilder实例, 所以这个局部变量随函数结束一起销毁没有任何问题, 其他9个类似的XlaOp同理, 这也就是前文例子中使用”inline”的原因, 如果理解了inline, 也就不难理解, 这里同样有inline的问题: 如果一个XlaComputation被多次调用, 那么就会copy多次, 产生同inline一样的代码膨胀
-47- XlaOp本质就是一个在当前XlaBuilder下instructions_的handle
-49- 前面已经将所有的XlaOpKernel都转换为HloInstructionProto存储在instructions_中, 所有依赖的”函数”XlaComputations也都存储在了embedded_中, 是时候构造最终的”main”函数的XlaComputation了
-64-构造完成

职责2: 编译cubin

e->tensorflow::XlaCompilationCache::BuildExecutable(entry->compilation_result, &entry->executable)-->通向xla/local_client 及 local_service
  xla::ExecutableBuildOptions build_options;
  compile_result = xla::LocalClient::Compile()  
    executable = xla::LocalService::CompileExecutable() 

职责3: 执行cubin

tensorflow::XlaRunOp::Compute()
  run_result = xla::LocalExecutable::Run() //friend class LocalClient
    return executable_->ExecuteOnStreamWrapper() 




Related:
Tensorflow XLA Service Buffer优化详解
Tensorflow XLA Service 详解 II
Tensorflow XLA Service 详解 I
Tensorflow XLA Client | HloModuleProto 详解
Tensorflow XlaOpKernel | tf2xla 机制详解
Tensorflow JIT 技术详解
Tensorflow JIT/XLA UML
Tensorflow OpKernel机制详解
Tensorflow Op机制详解
Tensorflow Optimization机制详解
Tensorflow 图计算引擎概述

Leave a Reply

This site uses Akismet to reduce spam. Learn how your comment data is processed.

Discover more from sketch2sky

Subscribe now to keep reading and get access to the full archive.

Continue reading