Adding a New Op to TensorFlow

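There are three ways to add an op to TensorFlow:
1. Without modifying the TensorFlow source: extend TensorFlow directly in the runtime environment using the .h and .so files shipped in the TensorFlow package. No TensorFlow build environment is needed, so this is the first choice for downstream development.
2. Modify the TensorFlow source and use its User-supplied Op interface to add the new op straight into the TensorFlow package. You only have to focus on the op implementation and do not need to debug TensorFlow's Bazel build, but in my opinion this approach is of limited use.
3. Modify the TensorFlow source and add the op to the TensorFlow package as a Standard Op. This is the approach required for contributing a new op to the community. You must implement the op and also modify TensorFlow's Bazel build; because TensorFlow relies on generated code, the latter is hard to debug, so use this approach with care.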

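This article walks through these three ways of adding a TfOp and closes with a brief look at adding an XlaOp. For approaches 2 and 3, which modify the source, two layers inside TensorFlow have to change: the Python interface and the TfOp. Each layer can be done in either of two ways, User-supplied Op or Standard Op, which gives 2*2 = 4 combinations. To keep things simple, this article implements XZeroOutOp (zeroes its input and outputs the result) as a User-supplied Op and XDoNothingOp (passes its input through to its output) as a Standard Op. In practice, TensorFlow decouples the two layers through code generation, so using one approach for the Python layer and another for the TF layer works fine.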

Customized Op

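Horovod is a typical framework that uses the Customized Op approach. Building on the TensorFlow, MPI and NCCL interfaces, Horovod implements its ring-allreduce op without modifying the TensorFlow source. An op written this way is implemented no differently from a Standard Op inside TensorFlow. The key is to include the relevant headers at compile time, "/usr/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/framework/", and to link against the corresponding library, "/usr/lib/python2.7/site-packages/tensorflow/libtensorflow_framework.so".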
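As a rough sketch of the compile step described above, the following Python helper assembles a g++ command line for a customized op. The helper name `build_op_compile_command`, the source and output file names, and the flag set are illustrative assumptions, not taken from Horovod's or TensorFlow's build scripts. The locations are the ones quoted above, except that the include root (rather than its `core/framework` subdirectory) is passed to `-I`, on the assumption that sources include headers as `tensorflow/core/framework/...`.

```python
import shlex

# Package location quoted in the article; adjust for your own installation.
TF_ROOT = "/usr/lib/python2.7/site-packages/tensorflow"


def build_op_compile_command(source, output, tf_root=TF_ROOT):
    """Return a g++ argument list for building a custom-op shared library.

    Hypothetical helper: the flag set below is an assumption for illustration.
    """
    return [
        "g++", "-std=c++11", "-shared", "-fPIC", "-O2",
        source,
        "-o", output,
        # Header search path: the include root shipped with the package.
        "-I" + tf_root + "/include",
        # Link against libtensorflow_framework.so in the package directory.
        "-L" + tf_root,
        "-ltensorflow_framework",
    ]


if __name__ == "__main__":
    cmd = build_op_compile_command("x_zero_out_op.cc", "x_zero_out_op.so")
    # Print a shell-ready command line instead of invoking the compiler.
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

Depending on how the installed TensorFlow was built, extra flags (for example an ABI define) may be needed; `tf.sysconfig.get_compile_flags()` and `tf.sysconfig.get_link_flags()` report the flags for an installed package.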

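Also note that the TensorFlow framework automatically converts CamelCase names to snake_case: an op named "HorovodAllreduce" is exposed as "horovod_allreduce" by the Python wrapper generated for the .so. Interested readers can look at the TensorFlow code that generates the Python wrappers.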
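The renaming rule can be approximated in a few lines of Python. This is a sketch of the behaviour described above, not TensorFlow's own generator code; the function name `camel_to_snake` is hypothetical, and edge cases (digits, runs of capitals) may be handled differently by TensorFlow.

```python
def camel_to_snake(op_name):
    """Approximate the CamelCase -> snake_case renaming of op names.

    An underscore is emitted before an uppercase letter (other than the
    first character) when the previous character is lowercase or the next
    character is lowercase.
    """
    out = []
    last = len(op_name) - 1
    for i, c in enumerate(op_name):
        if c.isupper() and i > 0:
            prev_lower = op_name[i - 1].islower()
            next_lower = i < last and op_name[i + 1].islower()
            if prev_lower or next_lower:
                out.append("_")
        out.append(c.lower())
    return "".join(out)


if __name__ == "__main__":
    for name in ("HorovodAllreduce", "XZeroOut", "XDoNothing"):
        print(name, "->", camel_to_snake(name))
```

Under this rule the two ops added later in this article, XZeroOut and XDoNothing, map to x_zero_out and x_do_nothing, which matches the wrapper names used in the Python layer.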

User-supplied Op

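TfOp layer changes

Place the custom op under tensorflow/core/user_ops/. With this approach you do not need to worry about modifying BUILD files: the build system automatically generates the C API and Python code for ops under user_ops, and the upper layer can use them simply by calling "gen_user_ops.". This differs somewhat from a standard TfOp, however: the op declaration and its implementation must live together in one file, otherwise you have to modify the BUILD files yourself to get it built.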

// cat tensorflow/core/user_ops/x_zero_out_op.cc
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

// Op declaration: one int32 input, one int32 output.
REGISTER_OP("XZeroOut")
    .Input("in: int32")
    .Output("out: int32");

// Kernel: fills an output of the same shape as the input with zeros.
class XZeroOutOp : public OpKernel {
 public:
  explicit XZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = context->input(0);
    auto flat_input_tensor = input_tensor.flat<int32>();

    // Allocate the output tensor with the input's shape.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto flat_output_tensor = output_tensor->flat<int32>();

    for (int64 i = 0; i < flat_input_tensor.size(); ++i) {
      flat_output_tensor(i) = 0;
    }
  }
};

// Registered for DEVICE_GPU as in the original article. Note that Compute()
// writes the output from host code; TensorFlow's own ZeroOut example
// registers such a kernel with DEVICE_CPU.
REGISTER_KERNEL_BUILDER(Name("XZeroOut").Device(DEVICE_GPU), XZeroOutOp);

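Python interface layer changes

Export the API directly in the source file "tensorflow/python/user_ops/user_ops.py". In the code below, `gen_*` is the Python wrapper that the build system generates for ops in the runtime. Because the TfOp here was added as a User-supplied Op named x_zero_out, the Python API layer sees it as `gen_user_ops.x_zero_out`.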

@tf_export(v1=['user_ops.x_zero_out'])
def zero_out(x):
    # Forward the input tensor to the generated wrapper for the XZeroOut op.
    return _gen_user_ops.x_zero_out(x)

Testing the Op

Finally, we can add unit test code (tensorflow/python/ops/user_ops_test.py) to verify that the op is correct.

import tensorflow as tf
from tensorflow.python.platform import test as test_lib

class ZeroOutTest(test_lib.TestCase):
  def testZeroOut(self):
    with self.test_session():
      result = tf.user_ops.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(result.eval(), [0, 0, 0, 0, 0])

# Without this entry point the file runs no test cases when executed.
if __name__ == "__main__":
  test_lib.main()

Add the test target to tensorflow/python/BUILD:

tf_py_test(
    name = "user_ops_test",
    size = "small",
    srcs = ["ops/user_ops_test.py"],
)

Run the test:

$ bazel test //tensorflow/python:user_ops_test
...
Target //tensorflow/python:user_ops_test up-to-date:
  bazel-bin/tensorflow/python/user_ops_test
INFO: Elapsed time: 15.679s, Critical Path: 10.96s
INFO: 2 processes: 2 local.
INFO: Build completed successfully, 2 total actions
//tensorflow/python:user_ops_test                                        PASSED in 0.8s

INFO: Build completed successfully, 2 total actions

Standard Op

TfOp layer changes

// cat tensorflow/core/ops/x_do_nothing_op.cc
// Op declaration: one float32 input, one float32 output.
#include "tensorflow/core/framework/op.h"
REGISTER_OP("XDoNothing")
    .Input("in: float32")
    .Output("out: float32");

// cat tensorflow/core/kernels/x_do_nothing_op.cc
// Kernel: forwards the input tensor to the output unchanged.
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class XDoNothingOp : public OpKernel {
 public:
  explicit XDoNothingOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = context->input(0);
    VLOG(0) << __func__;  // log every invocation of the kernel
    context->set_output(0, input_tensor);
  }
};
REGISTER_KERNEL_BUILDER(Name("XDoNothing").Device(DEVICE_GPU), XDoNothingOp);

Modify tensorflow/core/kernels/BUILD:

tf_kernel_library(
  name = "x_do_nothing_op",
  prefix = "x_do_nothing_op",
  deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
  ],
)

Modify tensorflow/core/BUILD:

tf_gen_op_libs(
    is_external = False,
    op_lib_names = [
        "x_do_nothing_op",
    ],
    deps = [
        ":lib",
        ":protos_all_cc",
    ],
)
cc_library(
    name = "ops",
    visibility = ["//visibility:public"],
    deps = [
        ":x_do_nothing_op_op_lib",
    ] + if_mkl([
        ":mkl_array_ops_op_lib",
        ":mkl_nn_ops_op_lib",
    ]) + tf_additional_cloud_op_deps(),
    alwayslink = 1,
)
cc_library(
    name = "all_kernels_impl",
    visibility = [":__subpackages__"],
    deps = [
        "//tensorflow/core/kernels:x_do_nothing_op",
    ],
)

Create tensorflow/core/api_def/base_api/api_def_XDoNothing.pbtxt:

op {
  graph_op_name: "XDoNothing"
  summary: "Does nothing."
}

Python interface layer changes

Create tensorflow/python/ops/x_do_nothing_op.py:

# cat tensorflow/python/ops/x_do_nothing_op.py
from tensorflow.python.ops import gen_x_do_nothing_op
from tensorflow.python.util.tf_export import tf_export

@tf_export('xx_do_nothing')
def xxx_do_nothing(a):
    # Forward the input to the generated wrapper for the XDoNothing op.
    return gen_x_do_nothing_op.x_do_nothing(a)

Modify tensorflow/python/ops/standard_ops.py:

from tensorflow.python.ops.x_do_nothing_op import *

Modify tensorflow/python/BUILD:

py_library(
    name = "no_contrib",
    ...
    deps = [
        ":x_do_nothing_op",
    ],
)
tf_gen_op_wrapper_private_py(
    name = "x_do_nothing_op_gen",
    visibility = [
        "//learning/brain/python/ops:__pkg__",
    ],
    deps = [
        "//tensorflow/core:x_do_nothing_op_op_lib",
    ],
)
py_library(
    name = "x_do_nothing_op",
    srcs = [
        "ops/x_do_nothing_op.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":x_do_nothing_op_gen",
    ],
)
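The file and target names in the steps above have to line up, and they are easy to mistype. The sketch below derives them from a single base name, following the pattern used in this article's XDoNothing example. The helper `standard_op_names` and its dictionary keys are hypothetical, and the pattern is inferred from the example rather than read from the Bazel macro definitions.

```python
def standard_op_names(base, op_name):
    """Derive the related file and target names for a Standard Op.

    base:    file/target stem, e.g. "x_do_nothing_op"
    op_name: registered op name, e.g. "XDoNothing"
    Hypothetical helper; the naming pattern mirrors this article's example.
    """
    return {
        "op_source": "tensorflow/core/ops/%s.cc" % base,
        "kernel_source": "tensorflow/core/kernels/%s.cc" % base,
        "kernel_target": "//tensorflow/core/kernels:%s" % base,
        "op_lib_target": "//tensorflow/core:%s_op_lib" % base,
        "api_def": "tensorflow/core/api_def/base_api/api_def_%s.pbtxt" % op_name,
        "py_gen_target": ":%s_gen" % base,
        "py_gen_module": "tensorflow.python.ops.gen_%s" % base,
        "py_source": "tensorflow/python/ops/%s.py" % base,
    }


if __name__ == "__main__":
    names = standard_op_names("x_do_nothing_op", "XDoNothing")
    for key in sorted(names):
        print("%-14s %s" % (key, names[key]))
```

Printing the table for a new op before editing the BUILD files gives a checklist to compare against, for instance that the `_gen` rule and the `_op_lib` dependency share the same stem.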

Testing the Op

Running bazel build //tensorflow:root_init_gen generates the Python interface, which is a quick way to check that the op was added correctly. If it succeeds, "bazel-bin/tensorflow/__init__.py" contains "from tensorflow.python import xxx_do_nothing as xx_do_nothing"; once installed on the system, the same line appears in "/usr/lib/python2.7/site-packages/tensorflow/__init__.py". In this example, we test the op with the following script:

import tensorflow as tf
sample = tf.random_normal(shape = [1000, 1000], mean=-1, stddev=4)
result = tf.xx_do_nothing(sample)
sess = tf.Session()
with sess:
    sess.run(result)

Test output:

2019-12-28 14:24:09.288989: I tensorflow/core/kernels/x_do_nothing_op.cc:9] Compute
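The quick check of the generated `__init__.py` described earlier in this section can also be scripted. The following is a minimal sketch; the helper name `find_export_lines` is hypothetical, and the path in the usage part is the one quoted in this article.

```python
import os


def find_export_lines(init_path, symbol):
    """Return the `from ... import ...` lines in init_path that mention symbol."""
    matches = []
    with open(init_path) as f:
        for line in f:
            stripped = line.strip()
            if stripped.startswith("from ") and symbol in stripped:
                matches.append(stripped)
    return matches


if __name__ == "__main__":
    path = "bazel-bin/tensorflow/__init__.py"
    if os.path.exists(path):
        hits = find_export_lines(path, "xx_do_nothing")
        print("\n".join(hits) if hits else "xx_do_nothing is not exported")
    else:
        print("not found: " + path)
```

An empty result suggests that the export decorator or the import in standard_ops.py was not picked up by the build.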

Adding an XlaOp

Create tensorflow/compiler/tf2xla/kernels/x_do_nothing_op.cc:

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
namespace {

class XDoNothingOp : public XlaOpKernel {
 public:
  explicit XDoNothingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    VLOG(0)<<__func__;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    VLOG(0)<<"XDoNothingOp"<<" "<<__func__;
    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp handle = ctx->Input(0);
    ctx->SetOutput(0, handle);
  }

};

REGISTER_XLA_OP(Name("XDoNothing").CompilationOnly(), XDoNothingOp);

}  // namespace
}  // namespace tensorflow

Modify tensorflow/compiler/tf2xla/kernels/BUILD:

tf_kernel_library(
    name = "xla_ops",
    srcs = [
        "x_do_nothing_op.cc",
    ...
