compiler/aot/ | 以AOT的方式将tf2xla/接入TF引擎 |
compiler/jit/ | 以JIT的方式将tf2xla/接入TF引擎, 核心是9个优化器和3个tfop,其中XlaCompileOp调用tf2xla的“编译”入口完成功能封装,XlaRunOp调用xla/client完成“运行”功能。 |
compiler/tf2xla/ | 对上提供xla_compiler.cc:XlaCompiler::CompileFunction()供jit:compile_fn()使用将cluster转化为XlaComputation。核心是利用xla/client提供的接口,实现对XlaOpKernel的“Symbolic Execution”功能。每个XlaOpKernel子类均做的以下工作: 从XlaOpKernelContext中取出XlaExpression或XlaOp, 调用xla/client/xla_buidler.h提供的方法完成计算, 将计算结果的XlaOp存入XlaKernelContext.** |
compiler/xla/client/ | 对上提供xla_builder.cc:Builder等供CompileFunction()使用,将Graph由Op表达转化为HloModuleProto:HloComputationProto:HloInstructionProto表达并保存在XlaComputation中。 对上提供local_client.cc:LocalClient::Compile(),作为编译入口供jit:BuildExecutable()使用,将已经得到的XlaComputation交给service并进一步编译为二进制。 对上提供local_client.cc:LocalExecutable::Run(),作为运行入口供jit/kernels/xla_ops.cc:XlaRunOp使用,通过Key找到相应的二进制交给service层处理 |
compiler/xla/service/ | 对上提供local_service.cc:LocalService::BuildExecutable()供LocalClient::Compile()使用实现真正的编译,承接XlaComputation封装的HloProto, 将其转化为HloModule:HloComputation:HloInstruction表达, 对其进行优化之后, 使用LLVM后端将其编译为相应Executable后端的二进制代码 对上提供executable.cc:Executable::ExecuteOnStream()供LocalExecutable::Run()使用实现真正的执行二进制。 |
compiler/xla/client 向上为tf2xla/下的XlaOpKernel的实现提供支撑, 将上层请求转换为HloModule交给下层xla/service优化并编译.
接口上, client做上表中的三件事 , 实际上, 只有XlaOpKernel->HloProto在Client完成, 而对于另外两个编译和执行,都是在service中完成的,或者说,Client本质上是Service的Facade + Proxy,对繁琐的构造HloModuleProto的过程进行了封装,降低了XLA Service与上层模块的耦合
client.h:Client | Client基类, 用于多态实现 |
client_library.h:ClientLibarary | 使用单例构造client_library对象, 用于检索/构造所需的Client实例 |
lib/ | 同builder一起实现”Symbolic Execution” |
local_client.h:LocalClient, LocalExecutable | JIT 使用的LocalClient定义, 是Service相关方法的Proxy |
xla_builder.h:XlaBuilder | 提供接口用tf2xla使用实现”Symbolic Execution”, 是其中XlaBuilder::Build()是构造client构造HloModuleProto的核心方法。对于需要多个步骤完成初始化的类, 我们会使用Builder模式, 这就是个例子 |
xla_computation.h:XlaComputation | XlaComputation对象是对HloModuleProto的封装, 用于进一步二进制编译 |
职责1: 构造HloModuleProto
编译二进制之前首先要完成Graph表达方式的映射: Client之前的tf2xla的Graph由使用Op表达, Client之后的Service的Graph使用HloInstruction表达, Client负责完成这种转化, 具体地, 就是将Op转化为HloProto格式, 再交给Service解析为Hlo格式, 其中的HloProto就是封装在XlaComputation中. 所以, 这个过程可以看做是”编译”的准备工作. 在这个过程中, Graph, Cluster, XlaComputation, HloModuleProto, HloModule是一一对应的,都是”Program”的概念, 这一点从源码的ProgramShape等变量名中都可以体现
按照OOP的方式理解OOP的代码, 在说明Client是如果构造HloModuleProto之前, 来介绍几个概念, HloModuleProto-HloComputationProto- HloInstructionProto, 这3个概念是XLA Service提供给上层调用者(Here, XLA Client)的用于构造HloModule-HloComputation-HloInstruction的protobuf形式的”原材料”, Module-Computation-Instruction, 在概念上可以对应程序-函数-语句
- HloProto中的函数调用都是以”半inline“方式实现, inline的部分是指类似”a = add(x, y);”这样的语句, 会将add函数体整个copy到该语句的上下文, 在执行的时候调用之, 在HloProto中, 被调用的XlaComputation也会被copy到调用语句的上下文, “上下文”就是调用语句HloInstructionProto所处的XlaComputation(对应一个HloModuleProto)的XlaBuilder实例. “半”的部分是指被copy的XlaComputation并不是直接在调用语句处展开, 而是在调用语句的HloModuleProto中存储被调用的XlaComputation的Id, 执行的时候直接跳转执行.
- 类似多文件编程, 构造的过程可以有多个HloModuleProto, 但由于上一条所述, 最终所有的被调用函数均会被copy到根HloModuleProto的上下文, 最终只有也只需要根HloModuleProto就足以以描述整个程序的逻辑
- 一个C/C++程序有main作为入口, 对应到HloModuleProto就是entry_computation, 一个函数的第一条语句, 对应到HloComputationProto就是root_id
下面我们简要分析下代码, 既然Hlo是描述一个程序, 那么就需要解决一个基本的问题: 描述出函数->语句->函数->语句…这样的递归结构, OOP的Composite模式, 或者内核的kset-kobj结构, 都是解决递归的好方式, Hlo模块的实现也比较简单, 可以认为是kset-kobj的结构的C++版本:在HloComputationProto里记录HloInstructionProto,在HloInstructionProto里记录被调用的另一个HloCProto实例Id
tensorflow::XlaCompilationCache::Compile(kernel/out_compilation_result, executable/out_executable) return tensorflow::XlaCompilationCache::CompileImpl(out_compilation_result, out_executable) if (!entry->compiled): e->std::function<Status(XlaCompiler* compiler,...(entry->compilation_result) tensorflow::XlaCompiler::CompileFunction(out_compilation_result/result, fn_name_attrs) tensorflow::XlaCompiler::CompileGraph() ExecuteGraph(context/xla_context) tensorflow::GraphCompiler::Compile() for node in topo_sorted_nodes: tensorflow::XlaCompilationDevice::Compute(op_context) tensorflow::XlaOpKernel::Compute() XlaOpKernelContext xla_context(context) Compile(&xla_context); //xla/client/xla_builder.cc //CrossReplicaSum b = CreateSubBuilder("sum"); return sub_builder = absl::make_unique<XlaBuilder>(computation_name); Add(b->Parameter(),b->Parameter(); auto computation = b->Build(); AddCalledComputation(computation, &instr) for e : computation.computations(): HloComputationProto new_computation(e); computation_id = GetNextId(); remapped_ids[new_computation.id()] = computation_id for instruction : *new_computation.mutable_instructions(): int64 instruction_id = GetNextId(); remapped_ids[instruction.id()] = instruction_id; new_computation.set_root_id(remapped_ids.at(new_computation.root_id())) imported_computations.push_back(new_computation); instr->add_called_computation_ids(remapped_ids.at(computation.proto().entry_computation_id())) for (auto& imported_computation : imported_computations): for (auto& instruction : *imported_computation.mutable_instructions()): for operand_id : *instruction.mutable_operand_ids(): operand_id = remapped_ids.at(operand_id); for control_predecessor_id : *instruction.mutable_control_predecessor_ids(): control_predecessor_id = remapped_ids.at(control_predecessor_id); for alled_computation_id : *instruction.mutable_called_computation_ids(): called_computation_id = remapped_ids.at(called_computation_id); computation_id = imported_computation.id(); embedded_.insert({computation_id, std::move(imported_computation)}); AddInstruction(HloInstructionProto instr) instr.set_id(handle); instr.set_opcode(HloOpcodeString(opcode)); instr.set_name() handle_to_index_[handle] = instructions_.size(); instructions_.push_back(std::move(instr)); XlaOp op(handle, this); return op BuildComputation() computation_status = builder->Build(); Build(instructions_.back().id()...) for instruction in instructions_: entry.add_instructions()->Swap(&instruction); XlaComputation computation(entry.id()); HloModuleProto* module = computation.mutable_proto(); module->set_name(entry.name()); module->set_id(entry.id()); module->set_entry_computation_name(entry.name()); module->set_entry_computation_id(entry.id()); *module->mutable_host_program_shape() = entry.program_shape(); for e in embedded_: module->add_computations()->Swap(e.second) module->add_computations()->Swap(&entry); computation = computation_status.ConsumeValueOrDie(); e->tensorflow::XlaCompilationCache::BuildExecutable()
-1- jit/xla_compilation_cache.cc
-2- jit的下边界, XlaCompiler::CompileFunction(tf2xla/xla_compiler.cc), 进入tf2xla
-3- 对于cache中没有的entry, 真正的去编译
-4- compile_fn()
-5- tf2xla/xla_compiler.cc
-9- 每一个真正执行编译的node
-14- tf2xla/kernel的边界, 用于构造XlaComputation, 个人认为是Client的核心文件, 里面实现的XlaOp是client对上(Here, tfxla)提供的接口
-15- CrossReplicaSum恰好是10个需要调用其他XlaComputation的XlaOp(Call, Conditional, CrossReplicaSum, Map, Reducem ReduceWindow, Scatter, SelectAndScatterWithGeneralPadding, Sort, While)之一, 简单又全面, 这里作为例子
-16- XlaBuilder里构造XlaBuilder, 类似编译a.c的时候遇到了对b.c的依赖, 将b.c 构造好在作为依赖项copy到a.c
-18- XlaComputation的函数体
-19- XlaBuilder::Build(), 参考
-51-, Build()的作用就是将该XlaBuilder下存储在instructions_的HloInstructionProto和embedded_的HloComputation构造HloModuleProto, 并封装到XlaComputation中并返回
-20- CrossReplicaSum通过AddCalledComputation添加所调用的XlaComputation, 这个函数整体上做2件事: 1. 将被调用的XlaComputation deep copy到当前的XlaBuilder的instruction_和embedded_, 2. 在调用点的instruction处记录该XlaComputation的entry_computation_id
-41- 将准备好的Instruction添加到当前的XlaBuilder, 注意, CrossReplicaSum通过CreateSubBuilder创造的XlaBuilder实例是局部变量, 调用方需要的信息已经完全copy到了调用方的XlaBuilder实例, 所以这个局部变量随函数结束一起销毁没有任何问题, 其他9个类似的XlaOp同理, 这也就是前文例子中使用”inline”的原因, 如果理解了inline, 也就不难理解, 这里同样有inline的问题: 如果一个XlaComputation被多次调用, 那么就会copy多次, 产生同inline一样的代码膨胀
-47- XlaOp本质就是一个在当前XlaBuilder下instructions_的handle
-49- 前面已经将所有的XlaOpKernel都转换为HloInstructionProto存储在instructions_中, 所有依赖的”函数”XlaComputations也都存储在了embedded_中, 是时候构造最终的”main”函数的XlaComputation了
-64-构造完成
职责2: 编译cubin
e->tensorflow::XlaCompilationCache::BuildExecutable(entry->compilation_result, &entry->executable)-->通向xla/local_client 及 local_service xla::ExecutableBuildOptions build_options; compile_result = xla::LocalClient::Compile() executable = xla::LocalService::CompileExecutable()
职责3: 执行cubin
tensorflow::XlaRunOp::Compute() run_result = xla::LocalExecutable::Run() //friend class LocalClient return executable_->ExecuteOnStreamWrapper()
Related:
Tensorflow XLA Service Buffer优化详解
Tensorflow XLA Service 详解 II
Tensorflow XLA Service 详解 I
Tensorflow XLA Client | HloModuleProto 详解
Tensorflow XlaOpKernel | tf2xla 机制详解
Tensorflow JIT 技术详解
Tensorflow JIT/XLA UML
Tensorflow OpKernel机制详解
Tensorflow Op机制详解
Tensorflow Optimization机制详解
Tensorflow 图计算引擎概述