Tensorflow Op机制详解

注册一个tfop分为两部分: OpOpKernel. 其中, Op是tfop的声明部分, 类似于函数的声明, 主要描述Op静态属性. OpKernel是tfop的实现部分, 同样类似于函数的实现, 主要描述OpKernel的具体计算逻辑. tf已经实现的Op相关源码在tensorflow/core/common_runtime/ops/中, OpKernel的部分在tensorflow/core/common_runtime/kernels/. 二者通过注册时使用的名字联系到一起.

接口形式

下面是一个注册Op的例子, 使用接口"REGISTER_OP()"‘注册一个Op,包括输入输出, Op属性等内容, 注册的Op最终会转换为OpDef, 即用protobuf格式存储的静态数据, 所以, 这里用到的属性也按照pb的格式以KV的形式编写. REGISTER_OP的实现在”tensorflow/core/framework/op.h OpDefBuilderWrapper” 中, 众多可设值中,"Attr"是一个允许自定义的值, 比如XLA引擎就根据自身需求提供了"XlaCompile", 如果一个Op将该值设置为true, 就会强制XLA引擎将其编译. 当然, 也可以设置一些无用的值, 就像函数声明里有一个并没有实际使用的参数, 除了浪费存储空间没有其他用途.

//tensorflow/core/common_runtime/nccl_ops.h
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("NcclAllReduce")
    .Input("input: T")
    .Output("data: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {half, float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .Attr("XlaCompile: bool=true")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape);
}

注册原理

在内部原理上, 上述代码可以进一步分解为两部分: 构造一个Op + 注册一个Op, 这部分代码主要在 “tensorflow/core/framework/ op.h(.cc) op_def_builder.h(.cc) ” 中

 //op.h
 #define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
 #define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
 #define REGISTER_OP_UNIQ(ctr, name)                                          \
   static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr    \
       TF_ATTRIBUTE_UNUSED =                                                  \
           ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
               name)>(name)

 //macros.h
 #define TF_ATTRIBUTE_UNUSED __attribute__((unused))

上述代码的关键在-5-8-, 等号的右边是构造了一个OpDefBuilderWrapper对象并设置好相应的属性, 等号的左侧是构造一个static的, 利用"__COUNTER__"唯一标识的, OpDefBuidlerReceiver对象, 这个对象构造的会对带注册Op做相应的检查, 如果检查通过将其转换成OpDef对象, 并进一步构造为一个OpRegistrationData对象并存储在下面这个registry_中,

  mutable std::unordered_map<string, const OpRegistrationData*> registry_

顺便说一句, 这里面OpDefBuilderWrapper的实现有个小trick: 我们知道构造函数返回的是一个对象, 那么就可以通过”.member_fcn()”的形式直接调用其成员函数, 这里, OpDefBuilderWrapper还进一步, 其相应的成员函数均是返回”*this”类型, 可以实现“链式”调用, 形如”OpDefBuilderWrapper instance.mem_fn0().mem_fn1()….”, 也正是这种设计, 才有了注册Op时清爽的表达.

class OpDefBuilderWrapper<true> {
 public:
  explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
  OpDefBuilderWrapper<true>& Attr(string spec) {
    builder_.Attr(std::move(spec));
    return *this;
  }
  OpDefBuilderWrapper<true>& Input(string spec) {
    builder_.Input(std::move(spec));
    return
  //...

调试方法

tf在 OpRegistry 中实现了很多Op操作方法, 可用于调试的主要有下面几个:

  Status LookUp(const string& op_type_name,
                const OpRegistrationData** op_reg_data) const override;
  
  // A singleton available at startup.
  static OpRegistry* Global();

  // Get all registered ops.
  void GetRegisteredOps(std::vector<OpDef>* op_defs);

  // Get all `OpRegistrationData`s.
  void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);




Related:
Tensorflow XLA HLO I — BufferLiveness
Tensorflow XLA Service 详解 II
Tensorflow XLA Service 详解 I
Tensorflow XLA Client | HloModuleProto 详解
Tensorflow XlaOpKernel | tf2xla 机制详解
Tensorflow JIT 技术详解
Tensorflow JIT/XLA UML
Tensorflow OpKernel机制详解
Tensorflow Op机制详解
Tensorflow Optimization机制详解
Tensorflow 图计算引擎概述

One comment on “Tensorflow Op机制详解

Leave a Reply

This site uses Akismet to reduce spam. Learn how your comment data is processed.

Discover more from sketch2sky

Subscribe now to keep reading and get access to the full archive.

Continue reading