Tensorflow JIT 技术详解

compiler/aot/ 以AOT的方式将tf2xla/接入TF引擎
compiler/jit/以JIT的方式将tf2xla/接入TF引擎, 核心是9个优化器和3个tfop,其中XlaCompileOp调用tf2xla的“编译”入口完成功能封装,XlaRunOp调用xla/client完成“运行”功能。
compiler/tf2xla/对上提供xla_compiler.cc:XlaCompiler::CompileFunction()供jit:compile_fn()使用将cluster转化为XlaComputation。核心是利用xla/client提供的接口,实现对XlaOpKernel的“Symbolic Execution”功能。每个XlaOpKernel子类均做的以下工作: **从XlaOpKernelContext中取出XlaExpression或XlaOp, 调用xla/client/xla_buidler.h提供的方法完成计算, 将计算结果的XlaOp存入XlaKernelContext.**
compiler/xla/client/ 对上提供xla_builder.cc:Builder等供CompileFunction()使用,将Graph由Op表达转化为HloModuleProto:HloComputationProto:HloInstructionProto表达并保存在XlaComputation中。
对上提供local_client.cc:LocalClient::Compile(),作为编译入口供jit:BuildExecutable()使用,将已经得到的XlaComputation交给service并进一步编译为二进制。
对上提供local_client.cc:LocalExecutable::Run(),作为运行入口供jit/kernels/xla_ops.cc:XlaRunOp使用,通过Key找到相应的二进制交给service层处理
compiler/xla/service/ 对上提供local_service.cc:LocalService::BuildExecutable()供LocalClient::Compile()使用实现真正的编译,承接XlaComputation封装的HloProto, 将其转化为HloModule:HloComputation:HloInstruction表达, 对其进行优化之后, 使用LLVM后端将其编译为相应Executable后端的二进制代码
对上提供executable.cc:Executable::ExecuteOnStream()供LocalExecutable::Run()使用实现真正的执行二进制。

XLA基于编译技术将静态子图转换为二进制进而实现在某些场景下的加速性能,以BERT为例,P40单卡每batch性能从850ms提升到了700ms

JIT 是目前TF中两种XLA应用方式之一, 借助TF对计算图的先优化再执行的机制, JIT使用9个优化器+3个基于XLA模块实现的Tfop将Just In Time技术接入TF图计算引擎, 使之具备了在某些场景下的加速图计算的能力. 

9个Optimization如下图所示:

其中, MarkForCompilationPass, EncapsulateSubgraphsPass 和 BuildXlaOpsPass 最为关键.

3个Op如下, 经过之前的优化, 已经完成了”图->XlaCompileOp + XlaRunOp”的转化, 而JIT的编译过程就在替换了原始Op的XlaCompile中进行. XlaCompileOp编译得到的二进制直接送到紧接其后的XlaRunOp中执行, 由于, XlaCompileOp里有用于存储之前编译结果的Cache, 所以理想情况下(图不变,输入的shape也不变), 只有第一次会真正的编译, 之后的step中由于Cache hit, XlaCompileOp的成本就很低了, 这也是XLA你能够实现加速核心原因. 据此, 在特征识别等输入频繁变动的场景, 由于XlaCompileOp的Cache Miss的概率大大增加, 整体性能就会比常规的TF执行引擎差. 

Optimization

以下面一个极其简单的计算图为例

MarkForCompilationPass

这个Optimazition主要的工作就是进行cluster的划分, 就是每一个Nodedef中添加”_XlaCluster” 属性, 其值为该node被划分到的cluster编号. 即, 从

EncapsulateSubgraphsPass

根据MarkForCompilationPass对Node cluster的标记, 将相同cluster下的所有Node封装到一个名为cluster_XXX的节点中, 比如, _XlaCluster的属性值为cluster_0的所有节点, 在这个阶段就会被封装到名为cluster_0 的节点下.在本例中, 转换为下图:

BuildXlaOpsPass

每一个 IsXlaCompiledKernel 都会被替换为XlaCompile + XlaRun

OpKernel

JIT实现的两个OpKernel的最重要的工作就是维持一个CompilationCache, 在Cache Miss的情况下, XlaCompileOp调用tf2xla/提供的编译接口, 获取cubin的同时将其存入CompilationCache以便后续使用, 获取到的cubin 存入以Tensor KeyT为键的XlaExecutableClosureStore::Global()中, 这是为了符合TF引擎使用OpContext在OpKernel之间传递数据的框架, 这里, Tensor KeyT就是XlaCompileOp的输出, 其后继的XlaRunOp从OpContext中取出KeyT, 并进一步取出XlaCompileOp存入XlaExecutableClosureStore::Global()中的相应二进制, 调用xla/client提供的接口执行之.

下面是笔者曾写的一个简单的TF计算图, 通过配置, 将Matmul和Square划归到一个Cluster, Substract和Add在另一个Cluster, 可以看到, 第1个step还有XlaCompileOp有调用ptx真正的执行编译, 第二次由于Cache Hit, 就不再有真实编译过程了

XlaCompileOp

XlaCompileOp主要做两个工作: 

  1. 调用tf2xla/xla_compiler.cc XlaCompiler::CompileFunction(), 借助tf2xla/kernels/将GraphDef转换为xla/client/xla_computation.h XlaComputation, 
  2. 调用xla/client/local_client.cc的LocalClient::Compile() 将XlaComputation编译为二进制cubin并存储在GpuExecutable中, 最终传递给后继的XlaRunOp

-4- xla/client/local_client.cc的LocalExecutable关联Executable, 后续的编译过程会将的编译的结果存储在Executable子类GpuExecutable::cubin_中, 
-7- XlaCompileOp核心功能, jit/xla_compilation_cache.cc
-10- signature是key, 用signature从cache中获取二进制, 这个entry, 对应一个XlaCompile的二进制结果, 而一个XlaCompile对应的是一个cluster
-11- Cache Miss, 对于cache中没有的entry, 真正的去编译
-8,12-13- 将Graph转换为XlaComputation对象, 其中核心工作就是调用tf2xla/kernel中的XlaOpKernel::Compile()方法, 完成这个映射
-15- 调用xla/client/local_client.cc的LocalClient::Compile()将XlaComputation编译为cubin, 并将结果封装到executable对象中
-17- 将XlaComputation带出
-18- 将含有cubin的executable带出
-19- 将executable存入XlaExecutableClosureStore::Global(), 并用KeyT索引, 等待XlaRunOp使用
-24- 设置XlaCompileOp作为一个OpKernel的输出数据, 包括用于找到cubin的key和编译成功标记

XlaRunOp

-2- 从OpContext中获取到XlaCompileOp存入的Key, 据此找到含有二进制的
-8- 执行xla/client/local_client.cc LocalExecutable::Run()->Executable::ExecuteOnStreamWrapper()((xla/service/executable.cc)->Executable::ExecuteOnStream()(xla/service/executable.cc)->GpuExecutable::ExecuteOnStream()(xla/service/gpu/gpu_executable.cc)调用到XlaCompile存入LocalExecutable中的二进制, .

Leave a Reply

This site uses Akismet to reduce spam. Learn how your comment data is processed.

%d bloggers like this: