为Tensorflow 添加新Op

Tensorflow内部添加op有3种方式:
1. 不修改Tensorflow源码, 在运行环境里通过Tensorflow包中的.h和.so直接扩展, 无需Tensorflow编译环境, 二次开发首选
2. 修改Tensorflow源码, 利用源码中的User-supplied Op接口直接将新Op添加到Tensorflow安装包中, 关注Op实现, 无需调试Tensorflow的Bazel编译框架, 但个人认为比较鸡肋
3. 修改Tensorflow源码, 以Standard Op的方式将Op添加在Tensorflow包中, 将新Op提交到社区需要这种方式. 需要实现Op, 同时需要修改Tensorflow的Bazel编译框架, 由于Tensorflow使用了自动生成代码的技术, 后者调试比较困难,谨慎使用

本文主要介绍这3种添加TfOp的方式, 最后, 简要介绍添加XlaOp的方法. 对于需要修改源码的方法2和方法3, 在Tensorflow内部均需要修改python接口和tfop两层, 而每一种又有两种方式:User-supplied Op 和Standard Op, 所以一共是2*2 = 4 种, 简化起见, 本文分别用用user-supplied op实现了XZeroOutOp(输入清零后输出), 用Standard Op 实现了XDoNothingOp(输入透传给输出), 在实际操作中, tensorflow内部这两层使用自动生成代码技术进行了解耦, python层和tf层使用不同的方式完全没问题.

Customized Op

Horovod就是使用Cutomized Op这种方式的典型框架, 基于Tensorflow 接口, MPI和NCCL接口, Horovod在不修改Tensorflow源码的前提下实现了Ring-allreduce 的Op. 使用这种方式Op实现的方式与Tensorflow内部的Standard Op并无区别, 关键是编译的时候要include相关的头文件:”/usr/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/framework/”, 以及链接相应的库:”/usr/lib/python2.7/site-packages/tensorflow/libtensorflow_framework.so”

此外需要注意, Tensorflow框架会自动将驼峰命名转化为下划线命名, 即名为”HorovodAllreduce”的Op, 在生成的so中, 对应”horovod_allreduce”, 感兴趣的同学可以参考Tensorflow自动生成py的代码.

User-supplied Op

TfOp层的修改

将自定义op放在tensorflow/core/user_ops/下, 使用这种方式, 用户就不必关心怎么去修改BUILD文件, 编译系统会自动为user_ops下的op生成相应的C API和python代码, 上层只需调用”gen_user_ops.”即可使用. 但这种方式与标准TfOp有些不同, 需要将Op的声明和实现都放在一起, 否则就需要自行修改BUILD文件完成编译

python接口层修改

直接在源码 “tensorflow/python/user_ops/user_ops.py” 中导出API, 下述代码中的gen_*是编译系统对运行时内op的python wrapper, 这里, 由于tfop使用了User-supplied op 这种方式添加了名为x_zero_out的op, 所以在python API 层看到的是gen_user_ops.x_zero_out

测试Op

最后, 我们可以添加单元测试代码(tensorflow/python/ops/user_ops_test.py)来验证op的正确性

Standard Op

TfOp层修改

修改 tensorflow/core/kernels/BUILD:

修改 tensorflow/core/BUILD:

创建tensorflow/core/api_def/base_api/api_def_XDoNothing.pbtxt

python接口层修改

创建 tensorflow/python/ops/x_do_nothing_op.py

修改 tensorflow/python/ops/standard_ops.py:

修改tensorflow/python/BUILD:

测试Op

执行bazel build //tensorflow:root_init_gen 可以生成python接口, 可以用来快速检查添加的op是否正确. 如果执行无误, 会在”bazel-bin/tensorflow/__init__.py” 中看到”from tensorflow.python import xxx_do_nothing as xx_do_nothing”, 如果安装到系统中, 就是在”/usr/lib/python2.7/site-packages/tensorflow/__init__.py”中的”from tensorflow.python import xxx_do_nothing as xx_do_nothing”. 在本例中, 我们用下述脚本对op进行测试:

测试结果:

添加XlaOp

创建tensorflow/compiler/tf2xla/kernels/x_do_nothing_op.cc

修改tensorflow/compiler/kernels/BUILD:

Leave a Reply

This site uses Akismet to reduce spam. Learn how your comment data is processed.