TensorFlow OptimizationPassRegistry Mechanism Explained

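The overall pipeline is roughly: init graph -> Grappler optimizes the whole graph -> the graph is split into partitions by device -> OptimizationPassRegistry passes optimize each partition -> graph execution. Like TensorFlow's other registration mechanisms, OptimizationPassRegistry is built on the concepts of a registry, a registration, and a registrar.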

In the source, an optimization pass is registered with REGISTER_OPTIMIZATION(), which is implemented as follows:

//core/common_runtime/optimization_registry.h
#define REGISTER_OPTIMIZATION(grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization)

#define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)

#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)         \
  static ::tensorflow::optimization_registration::OptimizationPassRegistration \
      register_optimization_##ctr(                                             \
          grouping, phase,                                                     \
          ::std::unique_ptr<::tensorflow::GraphOptimizationPass>(              \
              new optimization()),                                             \
          #optimization)

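The `new optimization()` expression constructs the pass being registered and wraps it in a unique_ptr. That unique_ptr is what the registry stores, and the registry manages the pass through it. Registration itself amounts to defining a static object of type 'OptimizationPassRegistration' named 'register_optimization_##ctr'; the preprocessor macro '__COUNTER__' supplies a unique suffix for the variable name.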
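The static-registration idiom described above can be tried outside TensorFlow. The following is a minimal sketch, not TensorFlow code: `ToyPass`, `ToyRegistration`, `RegisteredNames` and `TOY_REGISTER` are hypothetical names invented for illustration, and `__COUNTER__` is a compiler extension (available in GCC, Clang and MSVC) rather than standard C++.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for GraphOptimizationPass.
struct ToyPass {
  virtual ~ToyPass() = default;
  std::string name;
};

// Names recorded in registration order. A function-local static avoids
// depending on the initialization order of other globals.
std::vector<std::string>& RegisteredNames() {
  static std::vector<std::string>* names = new std::vector<std::string>;
  return *names;
}

// Constructing one of these objects is the registration side effect.
// This toy only records the name; the pass itself is dropped afterwards.
struct ToyRegistration {
  ToyRegistration(std::unique_ptr<ToyPass> pass, const std::string& pass_name) {
    pass->name = pass_name;
    RegisteredNames().push_back(pass->name);
  }
};

// The extra helper layer forces __COUNTER__ to expand to a number
// before it is pasted into the variable name with ##.
#define TOY_REGISTER(pass_type) TOY_REGISTER_UNIQ_HELPER(__COUNTER__, pass_type)
#define TOY_REGISTER_UNIQ_HELPER(ctr, pass_type) TOY_REGISTER_UNIQ(ctr, pass_type)
#define TOY_REGISTER_UNIQ(ctr, pass_type)        \
  static ToyRegistration toy_registration_##ctr( \
      std::unique_ptr<ToyPass>(new pass_type()), #pass_type)

struct FooPass : ToyPass {};
struct BarPass : ToyPass {};

// Each line defines a distinct static object whose constructor runs
// before main().
TOY_REGISTER(FooPass);
TOY_REGISTER(BarPass);
```

Within a single source file the static objects are constructed in definition order, so `RegisteredNames()` holds "FooPass" followed by "BarPass" by the time `main()` starts.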

Below is how an `OptimizationPassRegistration` object is constructed. As the trace shows, the constructor simply hands the pass we created over to the global 'global_optimization_registry'.

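//abridged call trace, not compilable as written
OptimizationPassRegistration::OptimizationPassRegistration(grouping, phase, pass, optimization_pass_name)
  pass->set_name(optimization_pass_name);
  OptimizationPassRegistry::Global()->Register(grouping, phase, std::move(pass));
    //Global(): creates the registry on first use and returns the same pointer afterwards
    static OptimizationPassRegistry* global_optimization_registry = new OptimizationPassRegistry;
    return global_optimization_registry;
    //Register(): appends the pass to the list for this (grouping, phase)
    groups_[grouping][phase].push_back(std::move(pass));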

The Register() call files each registered pass under its group and, within that group, under its phase.
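To make the group + phase bookkeeping concrete, here is a small self-contained sketch of a singleton registry backed by a nested map. `ToyRegistry`, `ToyPass` and `CountIn` are hypothetical names for illustration; only the `Global()`/`Register()` shape and the `groups_[grouping][phase]` indexing mirror the trace above.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for GraphOptimizationPass.
struct ToyPass {
  std::string name;
};

class ToyRegistry {
 public:
  enum Grouping { PRE_PLACEMENT, POST_PLACEMENT };

  // Created on first use and intentionally never destroyed, so it stays
  // valid for registrations made by static objects in any source file.
  static ToyRegistry* Global() {
    static ToyRegistry* registry = new ToyRegistry;
    return registry;
  }

  // Takes ownership of the pass and files it under (grouping, phase).
  void Register(Grouping grouping, int phase, std::unique_ptr<ToyPass> pass) {
    groups_[grouping][phase].push_back(std::move(pass));
  }

  // Helper for this sketch only: number of passes in one (grouping, phase).
  std::size_t CountIn(Grouping grouping, int phase) const {
    auto group = groups_.find(grouping);
    if (group == groups_.end()) return 0;
    auto passes = group->second.find(phase);
    return passes == group->second.end() ? 0 : passes->second.size();
  }

 private:
  // group -> phase id -> passes in registration order.
  std::map<Grouping, std::map<int, std::vector<std::unique_ptr<ToyPass>>>> groups_;
};
```

Calling `ToyRegistry::Global()->Register(ToyRegistry::PRE_PLACEMENT, 0, ...)` twice leaves two passes in that slot, while every other (group, phase) slot stays empty.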

Registered passes are executed with 'RunGrouping()':

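//optimization_registry.cc (abridged, logging omitted)
Status OptimizationPassRegistry::RunGrouping(
    Grouping grouping, const GraphOptimizationPassOptions& options) {
  auto group = groups_.find(grouping);
  if (group != groups_.end()) {
    for (auto& phase : group->second) {    // phases, ascending phase id
      for (auto& pass : phase.second) {    // passes, registration order
        Status s = pass->Run(options);
        if (!s.ok()) return s;             // stop at the first failing pass
      }
    }
  }
  return Status::OK();
}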

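The nested loops in RunGrouping() walk every phase of the given group in ascending phase-id order, and every pass within each phase. The ascending order follows from the phases being stored in an ordered map keyed by phase id. Because the group and phase are fixed at registration time, a pass with a smaller phase id always runs before one with a larger phase id; passes that share a group and a phase id run in the order they were registered.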
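The ordering rule can be checked with a few lines of standalone C++. This is a sketch under the assumption that phases live in a `std::map` keyed by phase id; `PhaseMap`, `ExecutionOrder` and `DemoOrder` are hypothetical names, and pass names stand in for real pass objects.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical miniature of one group: phase id -> pass names in
// registration order.
using PhaseMap = std::map<int, std::vector<std::string>>;

// Returns pass names in the order a RunGrouping-style loop visits them.
std::vector<std::string> ExecutionOrder(const PhaseMap& phases) {
  std::vector<std::string> order;
  for (const auto& phase : phases) {         // std::map iterates keys ascending
    for (const auto& pass : phase.second) {  // vector keeps insertion order
      order.push_back(pass);
    }
  }
  return order;
}

// Registers passes out of phase order to show that the phase id, not the
// registration time, decides the order between phases.
std::vector<std::string> DemoOrder() {
  PhaseMap phases;
  phases[30].push_back("C");
  phases[10].push_back("A1");
  phases[20].push_back("B");
  phases[10].push_back("A2");
  return ExecutionOrder(phases);
}
```

`DemoOrder()` yields A1, A2, B, C: phase 10 comes first even though phase 30 was registered first, and A1 precedes A2 because it was registered earlier within the same phase.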

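By design, optimization_registry imposes no ordering between groups: the passes of any group can be run directly. TensorFlow's own usage does follow an order, however. TensorFlow defines the groups below, each corresponding to a successive stage on the way to graph execution, so in practice the group also serves as a coarse, time-ordered classification.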
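The point that ordering between groups comes from the caller rather than from the registry can be sketched as follows. Everything here is a toy: the free function `RunGrouping`, the `Passes` alias and `DemoPipeline` are hypothetical, the "<placement>" and "<partitioning>" log entries merely mark where those stages would happen, and only the group names are borrowed from TensorFlow.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

enum Grouping {
  PRE_PLACEMENT,
  POST_PLACEMENT,
  POST_REWRITE_FOR_EXEC,
  POST_PARTITIONING,
};

// group -> phase id -> pass names (names stand in for real pass objects).
using Passes = std::map<Grouping, std::map<int, std::vector<std::string>>>;

// Appends the names of the passes in one group to `log`; other groups are
// not touched, and an unregistered group is a no-op.
void RunGrouping(const Passes& passes, Grouping grouping,
                 std::vector<std::string>* log) {
  auto group = passes.find(grouping);
  if (group == passes.end()) return;
  for (const auto& phase : group->second) {
    for (const auto& name : phase.second) log->push_back(name);
  }
}

std::vector<std::string> DemoPipeline() {
  Passes passes;
  passes[POST_PARTITIONING][0].push_back("after_partitioning");
  passes[PRE_PLACEMENT][0].push_back("before_placement");

  std::vector<std::string> log;
  // The caller decides when each group runs relative to the other stages.
  RunGrouping(passes, PRE_PLACEMENT, &log);
  log.push_back("<placement>");
  RunGrouping(passes, POST_PLACEMENT, &log);  // nothing registered: no-op
  log.push_back("<partitioning>");
  RunGrouping(passes, POST_PARTITIONING, &log);
  return log;
}
```

Nothing in the registry itself prevents calling the groups in a different order; the sequence above exists only because the driver code invokes them that way.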

//optimization_registry.h
// Groups of passes are run at different points in initialization.
enum Grouping {
  PRE_PLACEMENT,          // after cost model assignment, before placement.
  POST_PLACEMENT,         // after placement.
  POST_REWRITE_FOR_EXEC,  // after re-write using feed/fetch endpoints.
  POST_PARTITIONING,      // after partitioning
};

Listed roughly in execution order, the optimization passes registered in TensorFlow 1.14 are:

REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, ParallelConcatRemovePass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, AccumulateNV2RemovePass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, LowerFunctionalOpsPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0, NcclReplacePass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25, IsolatePlacerInspectionRequiredOpsPass);
//the 9 XLA-related passes
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25, IntroduceFloatingPointJitterPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, EncapsulateXlaComputationsPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27, FunctionalizeControlFlowPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 5, CloneConstantsForBetterClusteringPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, IncreaseDynamismForAutoJitPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, PartiallyDeclusterPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, EncapsulateSubgraphsPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50, BuildXlaOpsPass);

REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);



