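There are three ways to add an op to TensorFlow:
1. Without modifying the TensorFlow source: extend TensorFlow directly in the runtime environment using the .h and .so files shipped in the TensorFlow package. No TensorFlow build environment is needed, which makes this the first choice for secondary development.
2. By modifying the TensorFlow source and using its User-supplied Op interface to add the new op directly to the TensorFlow package. You only have to care about the op implementation and never need to debug TensorFlow's Bazel build, but in my opinion this approach is of limited practical value.
3. By modifying the TensorFlow source and adding the op to the TensorFlow package as a Standard Op. This is the approach required when contributing a new op to the community. You have to implement the op and also modify TensorFlow's Bazel build. Because TensorFlow relies on generated code, the build changes are hard to debug, so use this approach with care.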
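This article covers these three ways of adding a TfOp and closes with a brief look at adding an XlaOp. Methods 2 and 3, which modify the source, both require changes at two layers inside TensorFlow: the Python interface and the TfOp. Each layer can be done either as a User-supplied Op or as a Standard Op, giving 2 × 2 = 4 combinations. To keep things simple, this article implements XZeroOutOp (outputs zeros in the shape of the input) as a User-supplied Op and XDoNothingOp (passes the input through to the output) as a Standard Op. In practice, TensorFlow decouples the two layers through generated code, so the Python layer and the TfOp layer can use different approaches without any problem.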
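To make the intended behavior of the two example ops concrete, here is a minimal pure-Python reference model. It is only a sketch of the semantics described above, not TensorFlow code; the function names `reference_x_zero_out` and `reference_x_do_nothing` are hypothetical, and they operate on flat Python lists instead of tensors.

```python
def reference_x_zero_out(values):
    """Model of XZeroOut: same length as the input, every element set to 0."""
    return [0 for _ in values]


def reference_x_do_nothing(values):
    """Model of XDoNothing: the input is forwarded to the output unchanged."""
    return list(values)


if __name__ == "__main__":
    print(reference_x_zero_out([5, 4, 3, 2, 1]))
    print(reference_x_do_nothing([1.5, -2.0]))
```

A model like this can serve as the source of expected values when writing unit tests for the real kernels.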
Customized Op
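Horovod is a typical framework built on the Customized Op approach: on top of the TensorFlow, MPI, and NCCL interfaces, it implements a ring-allreduce op without modifying the TensorFlow source. An op written this way is implemented exactly like a Standard Op inside TensorFlow. The key is to include the relevant headers at compile time, "/usr/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/framework/", and to link against the corresponding library, "/usr/lib/python2.7/site-packages/tensorflow/libtensorflow_framework.so".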
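The include and link steps can be sketched as a small helper that assembles the compiler command line. This is an illustration under assumptions, not Horovod's actual build: the helper name `custom_op_build_command` and the file names `my_op.cc` and `my_op.so` are hypothetical, and the include root is assumed to be the package's `include` directory so that `#include "tensorflow/core/framework/op.h"` resolves.

```python
def custom_op_build_command(sources, output, include_dir, lib_dir, compiler="g++"):
    """Return an argv list that compiles `sources` into a shared library."""
    return (
        [compiler, "-std=c++11", "-shared", "-fPIC", "-O2"]
        + list(sources)
        + ["-o", output]
        + ["-I" + include_dir]
        # Link against libtensorflow_framework.so found in lib_dir.
        + ["-L" + lib_dir, "-ltensorflow_framework"]
    )


if __name__ == "__main__":
    # Install location taken from the text above (an assumption for other setups).
    site = "/usr/lib/python2.7/site-packages/tensorflow"
    cmd = custom_op_build_command(["my_op.cc"], "my_op.so", site + "/include", site)
    print(" ".join(cmd))
```

On a machine with TensorFlow installed, `tf.sysconfig.get_include()` and `tf.sysconfig.get_lib()` report these two directories, which avoids hard-coding the install path.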
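Also note that TensorFlow automatically converts CamelCase op names to snake_case: an op registered as "HorovodAllreduce" is exposed as "horovod_allreduce" by the Python wrapper generated for the .so. Interested readers can look at the TensorFlow code that auto-generates the Python wrappers.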
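The naming rule can be approximated in a few lines of Python. This is a simplified sketch of the CamelCase to snake_case mapping, not TensorFlow's own generator code; the helper name `camel_to_snake` is hypothetical, and unusual op names may be handled differently by TensorFlow.

```python
import re


def camel_to_snake(op_name):
    """Approximate the CamelCase -> snake_case mapping for op names."""
    # Split before a capitalized word that follows any character: "XZero" -> "X_Zero".
    partial = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", op_name)
    # Split between a lowercase letter or digit and an uppercase letter.
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", partial).lower()


if __name__ == "__main__":
    for name in ("HorovodAllreduce", "XZeroOut", "XDoNothing"):
        print(name, "->", camel_to_snake(name))
```

With a compiled custom op, the snake_case name is the attribute to look up on the module returned by `tf.load_op_library`.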
User-supplied Op
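Changes to the TfOp layer
Put the custom op under tensorflow/core/user_ops/. With this approach you do not need to touch any BUILD file: the build system automatically generates the C API and Python code for ops under user_ops, and the upper layer can use them simply by calling "gen_user_ops.". It differs from a standard TfOp in one respect: the op registration and the kernel implementation must live in the same file; otherwise you have to modify the BUILD files yourself to get it compiled.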
    // cat tensorflow/core/user_ops/x_zero_out_op.cc
    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/framework/op_kernel.h"

    using namespace tensorflow;

    // Op registration: one int32 input, one int32 output.
    REGISTER_OP("XZeroOut")
        .Input("in: int32")
        .Output("out: int32");

    // Kernel: allocates an output with the input's shape and fills it with zeros.
    class XZeroOutOp : public OpKernel {
     public:
      explicit XZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

      void Compute(OpKernelContext* context) override {
        const Tensor& input_tensor = context->input(0);

        Tensor* output_tensor = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
        auto flat_output_tensor = output_tensor->flat<int32>();

        for (int i = 0; i < flat_output_tensor.size(); ++i) {
          flat_output_tensor(i) = 0;
        }
      }
    };

    // The loop above writes the output from host code, so register the kernel for the CPU.
    REGISTER_KERNEL_BUILDER(Name("XZeroOut").Device(DEVICE_CPU), XZeroOutOp);
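Changes to the Python interface layer
Export the API directly in the source file "tensorflow/python/user_ops/user_ops.py". In the code below, `gen_*` is the Python wrapper that the build system generates for ops in the runtime. Because the TfOp was added as a User-supplied Op named x_zero_out, the Python API layer sees it as `gen_user_ops.x_zero_out`.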
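    # Added to tensorflow/python/user_ops/user_ops.py, which already imports
    # _gen_user_ops and tf_export.
    @tf_export(v1=['user_ops.x_zero_out'])
    def x_zero_out(input):
      # Forward the input tensor to the generated wrapper.
      return _gen_user_ops.x_zero_out(input)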
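The relationship between the exported API name and the Python function behind it can be illustrated with a toy decorator. This is a deliberately simplified model and not how `tf_export` is implemented; `API_REGISTRY`, `export`, and `generated_x_zero_out` are hypothetical names used only for illustration.

```python
API_REGISTRY = {}  # hypothetical: public API name -> Python callable


def export(api_name):
    """Toy decorator that publishes a function under `api_name`."""
    def decorator(func):
        API_REGISTRY[api_name] = func
        return func
    return decorator


def generated_x_zero_out(values):
    """Stand-in for the generated wrapper gen_user_ops.x_zero_out."""
    return [0 for _ in values]


@export("user_ops.x_zero_out")
def zero_out(values):
    # Callers use the exported name, not the Python function name.
    return generated_x_zero_out(values)


if __name__ == "__main__":
    print(sorted(API_REGISTRY))
    print(API_REGISTRY["user_ops.x_zero_out"]([5, 4, 3, 2, 1]))
```

The point of the sketch is that callers and tests should use the exported name, here `user_ops.x_zero_out`, whatever the wrapped function happens to be called.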
Testing the op
Finally, we can add unit test code (tensorflow/python/ops/user_ops_test.py) to verify that the op is correct.
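    import tensorflow as tf
    from tensorflow.python.platform import test as test_lib


    class ZeroOutTest(test_lib.TestCase):

      def testZeroOut(self):
        with self.test_session():
          # The op is exported as user_ops.x_zero_out by tf_export.
          result = tf.user_ops.x_zero_out([5, 4, 3, 2, 1])
          self.assertAllEqual(result.eval(), [0, 0, 0, 0, 0])


    if __name__ == "__main__":
      # Without this call the file defines the test case but never runs it.
      test_lib.main()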
    # tensorflow/python/BUILD
    tf_py_test(
        name = "user_ops_test",
        size = "small",
        srcs = ["ops/user_ops_test.py"],
    )
    $ bazel test //tensorflow/python:user_ops_test
    ...
    Target //tensorflow/python:user_ops_test up-to-date:
      bazel-bin/tensorflow/python/user_ops_test
    INFO: Elapsed time: 15.679s, Critical Path: 10.96s
    INFO: 2 processes: 2 local.
    INFO: Build completed successfully, 2 total actions
    //tensorflow/python:user_ops_test PASSED in 0.8s
    INFO: Build completed successfully, 2 total actions
Standard Op
Changes to the TfOp layer
    // cat tensorflow/core/ops/x_do_nothing_op.cc
    // Op registration: one float32 input, one float32 output.
    #include "tensorflow/core/framework/op.h"

    REGISTER_OP("XDoNothing")
        .Input("in: float32")
        .Output("out: float32");
    // cat tensorflow/core/kernels/x_do_nothing_op.cc
    // Kernel: forwards the input tensor to the output without copying.
    #include "tensorflow/core/framework/op_kernel.h"
    using namespace tensorflow;
    class XDoNothingOp : public OpKernel {
     public:
      explicit XDoNothingOp(OpKernelConstruction* context) : OpKernel(context) {}
      void Compute(OpKernelContext* context) override {
        const Tensor& input_tensor = context->input(0);
        VLOG(0) << __func__;  // Logs "Compute" each time the kernel runs.
        context->set_output(0, input_tensor);
      }
    };
    REGISTER_KERNEL_BUILDER(Name("XDoNothing").Device(DEVICE_GPU), XDoNothingOp);
Modify tensorflow/core/kernels/BUILD:
    tf_kernel_library(
        name = "x_do_nothing_op",
        prefix = "x_do_nothing_op",
        deps = [
            "//tensorflow/core:framework",
            "//tensorflow/core:lib",
        ],
    )
Modify tensorflow/core/BUILD:
    tf_gen_op_libs(
        is_external = False,
        op_lib_names = [
            "x_do_nothing_op",
        ],
        deps = [
            ":lib",
            ":protos_all_cc",
        ],
    )

    cc_library(
        name = "ops",
        visibility = ["//visibility:public"],
        deps = [
            ":x_do_nothing_op_op_lib",
        ] + if_mkl([
            ":mkl_array_ops_op_lib",
            ":mkl_nn_ops_op_lib",
        ]) + tf_additional_cloud_op_deps(),
        alwayslink = 1,
    )

    cc_library(
        name = "all_kernels_impl",
        visibility = [":__subpackages__"],
        deps = [
            "//tensorflow/core/kernels:x_do_nothing_op",
        ],
    )
Create tensorflow/core/api_def/base_api/api_def_XDoNothing.pbtxt:
    op {
      graph_op_name: "XDoNothing"
      summary: "Does nothing."
    }
Changes to the Python interface layer
Create tensorflow/python/ops/x_do_nothing_op.py:
    # cat tensorflow/python/ops/x_do_nothing_op.py
    from tensorflow.python.ops import gen_x_do_nothing_op
    from tensorflow.python.util.tf_export import tf_export


    @tf_export('xx_do_nothing')
    def xxx_do_nothing(a):
      # Forward the input tensor to the generated wrapper.
      return gen_x_do_nothing_op.x_do_nothing(a)
Modify tensorflow/python/ops/standard_ops.py:
from tensorflow.python.ops.x_do_nothing_op import *
Modify tensorflow/python/BUILD:
    py_library(
        name = "no_contrib",
        ...
        deps = [
            ":x_do_nothing_op",
        ],
    )

    tf_gen_op_wrapper_private_py(
        name = "x_do_nothing_op_gen",
        visibility = [
            "//learning/brain/python/ops:__pkg__",
        ],
        deps = [
            "//tensorflow/core:x_do_nothing_op_op_lib",
        ],
    )

    py_library(
        name = "x_do_nothing_op",
        srcs = [
            "ops/x_do_nothing_op.py",
        ],
        srcs_version = "PY2AND3",
        deps = [
            ":x_do_nothing_op_gen",
        ],
    )
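Several target and module names in these BUILD changes are derived from one base name, so a single misspelling breaks the chain. The sketch below collects the naming pattern used by the snippets in this article and checks two BUILD entries against it. The helpers `derived_names` and `check_consistent` are hypothetical, and the rule that the generated module drops the "_gen" suffix and gains a "gen_" prefix is an assumption inferred from the `gen_x_do_nothing_op` import used earlier.

```python
def derived_names(base):
    """Collect the names that the BUILD snippets derive from `base`."""
    gen_target = base + "_gen"
    return {
        "op_lib_target": "//tensorflow/core:" + base + "_op_lib",
        "kernel_target": "//tensorflow/core/kernels:" + base,
        "gen_target": ":" + gen_target,
        # Assumption: module name = "gen_" + gen target name without "_gen".
        "python_module": "tensorflow.python.ops.gen_" + gen_target[: -len("_gen")],
    }


def check_consistent(gen_rule_name, op_lib_dep, base):
    """Return a list of mismatches between BUILD entries and the expected names."""
    expected = derived_names(base)
    problems = []
    if ":" + gen_rule_name != expected["gen_target"]:
        problems.append("gen rule name: " + gen_rule_name)
    if op_lib_dep != expected["op_lib_target"]:
        problems.append("op lib dep: " + op_lib_dep)
    return problems


if __name__ == "__main__":
    print(derived_names("x_do_nothing_op"))
    # A misspelled pair such as "x_do_thing_op_*" is reported twice.
    print(check_consistent("x_do_thing_op_gen",
                           "//tensorflow/core:x_do_thing_op_op_lib",
                           "x_do_nothing_op"))
```

Running a check like this before invoking Bazel can catch a mistyped target name earlier than a failed build would.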
Testing the op
Running bazel build //tensorflow:root_init_gen generates the Python interface and is a quick way to check that the op was added correctly. If it succeeds, "bazel-bin/tensorflow/__init__.py" contains "from tensorflow.python import xxx_do_nothing as xx_do_nothing"; once installed on the system, the same line appears in "/usr/lib/python2.7/site-packages/tensorflow/__init__.py". In this example we test the op with the following script:
    import tensorflow as tf

    sample = tf.random_normal(shape=[1000, 1000], mean=-1, stddev=4)
    result = tf.xx_do_nothing(sample)

    with tf.Session() as sess:
      sess.run(result)
Test output:
2019-12-28 14:24:09.288989: I tensorflow/core/kernels/x_do_nothing_op.cc:9] Compute
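The quick check of the generated `__init__.py` described in this section can also be automated. The following is a minimal sketch; the helper name `find_export_line` is hypothetical, and it only does a plain text search for the import line quoted above.

```python
def find_export_line(init_text, symbol, alias):
    """Return the matching import line from a generated __init__.py, or None."""
    wanted = "import {} as {}".format(symbol, alias)
    for line in init_text.splitlines():
        stripped = line.strip()
        if stripped.startswith("from tensorflow.python") and stripped.endswith(wanted):
            return stripped
    return None


if __name__ == "__main__":
    # Sample text standing in for the contents of the generated file.
    sample = "\n".join([
        "from tensorflow.python import zeros as zeros",
        "from tensorflow.python import xxx_do_nothing as xx_do_nothing",
    ])
    print(find_export_line(sample, "xxx_do_nothing", "xx_do_nothing"))
```

In practice the text would be read from "bazel-bin/tensorflow/__init__.py" after the `root_init_gen` build finishes.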
Adding an XlaOp
Create tensorflow/compiler/tf2xla/kernels/x_do_nothing_op.cc:
    #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
    #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
    #include "tensorflow/compiler/xla/client/xla_builder.h"

    namespace tensorflow {
    namespace {

    class XDoNothingOp : public XlaOpKernel {
     public:
      explicit XDoNothingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
        VLOG(0) << __func__;
      }

      void Compile(XlaOpKernelContext* ctx) override {
        VLOG(0) << "XDoNothingOp" << " " << __func__;
        // Pass the input XLA handle straight through as the output.
        xla::XlaOp handle = ctx->Input(0);
        ctx->SetOutput(0, handle);
      }
    };

    REGISTER_XLA_OP(Name("XDoNothing").CompilationOnly(), XDoNothingOp);

    }  // namespace
    }  // namespace tensorflow
Modify tensorflow/compiler/tf2xla/kernels/BUILD:
    tf_kernel_library(
        name = "xla_ops",
        srcs = [
            "x_do_nothing_op.cc",
            ...