未必孤独网 > 请问自动微分法(Automatic differentiation)是如何用C++实现的?

请问自动微分法(Automatic differentiation)是如何用C++实现的?

【李瞬生的回答(242票)】:

实现 AD 有两种方式,函数重载代码生成。两种方式的原理都一样,链式法则

不难想象,任何计算都可以由第1步到第k步的序列形式,其中第 i 步计算的输入,在之前的 i-1 步中已经计算(例如编译器生成的汇编指令序列)。因此,任何计算都可以看作形式如下图左侧的复合函数。微积分中的链式法则告诉我们,符合函数的导数可写作下图右侧的形式(假设每一步都可导)。请注意偏导数全导数的区别。

图一

自动微分的第一个难点就在这,微积分中的链式法则。大家在课堂上学的链式法则的示例通常只有两到三个函数,而自动微分面对的计算,有无数个函数。许多人不习惯在这样大的规模上应用链式法则。不过一旦习惯,就会发现自动微分的原理十分简单。

如果上述内容过于抽象,请参看下面这个例子以后再看一遍。

图二

在上图中,顶部方程是二维旋转。输入

有三个变量,旋转角度及二维坐标。输出

有两个变量,即旋转后的二维坐标。上图列表第二列是该计算的序列形式,第一列是每一步对应的表达式,第三列是对应的链式法则(请对比图一)。

太繁琐了?看不出个所以然?不用担心。

如果把计算的序列形式及其导数计算每个步骤的依赖关系表示成图

图三

不难发现,两张图是等价的。也就是说,计算序列形式的每一步都与其导数计算的步骤有一一对应的关系。源程序怎么算,其导数就可以怎么算(从顺序上来说)。不难发现,两张图是等价的。也就是说,计算序列形式的每一步都与其导数计算的步骤有一一对应的关系。源程序怎么算,其导数就可以怎么算(从顺序上来说)。

以上便是自动微分的基本原理。下面我们来谈实现。

如图二,图三,我们有两种方式来考虑自动微分的实现。

  1. 用户提供图二第二列序列形式的源代码,按顺序生成第三列的微分计算。此种方法的特点是,读一行源代码,生成一行微分计算,因此可以动态生成。

    若源代码这一行在做乘法,那么就依据乘法法则生成该步的微分计算。若源代码这一行是三角函数 cos(x),那么它对应的微分计算就是 -sin(x),以此类推。每一步计算的偏导数都根据链式法则组合,得出该步骤的全导数。

    该种方法的常见手段是函数重载。优点是简单直接,缺点是动态生成成本较高。

  2. 用户提供源代码,在编译时生成图三左侧的程序结构图,并生成图三右侧对应的微分程序。

    该种方法的常见手段是编译时的代码生成(比如用 flex-bison 做词法、语法分析)。优点是静态生成效率高,一次生成,多次使用。缺点是编译原理有门槛,非计算机专业望而却步。

力求简单,下面给出以函数重载实现的简单示例 SAD - Simple Automatic Differentiation。请对照图二中的列表阅读。

main 函数中对应图二表中第二列,2D 旋转的计算。ADV 里重载的 +, -, *, sin, cos 不仅完成本来的计算,还负责图二表中第三列导数的计算。

main 函数中 x, y, z 的序号都与图二中对应。

#include cmath#include vector#include iostreamnamespace SAD // Simple Automatic Differentiation{class ADV{public:ADV(double v = 0, double d = 0);// overloaded unary and binary operatorsADV operator + (const ADV x) const;ADV operator - (const ADV x) const;ADV operator * (const ADV x) const;friend ADV sin(const ADV x);friend ADV cos(const ADV x);double val; // value of the variabledouble dval; // derivative of the variable};ADV::ADV(double v, double d) : val(v), dval(d) {}ADV ADV::operator+(const ADV x) const{ADV y;y.val = val + x.val;y.dval = dval + x.dval;return y;}ADV ADV::operator-(const ADV x) const{ADV y;y.val = val - x.val;y.dval = dval - x.dval; // sum rulereturn y;}ADV ADV::operator*(const ADV x) const{ADV y;y.val = val*x.val;y.dval = x.val*dval + val*x.dval; // product rulereturn y;}ADV sin(const ADV x){ADV y;y.val = std::sin(x.val);y.dval = std::cos(x.val)*x.dval; // chain rulereturn y;}ADV cos(const ADV x){ADV y;y.val = std::cos(x.val);y.dval = -std::sin(x.val)*x.dval; // chain rulereturn y;}}int main(){using namespace SAD;using namespace std;static const double PI = 3.1415926;vectorADV x;x.emplace_back(PI, 1); // x = [PI, 2, 1]x.emplace_back(2, 0);x.emplace_back(1, 0);ADV y1 = cos(x[0]);ADV y2 = sin(x[0]);ADV y3 = x[1] * y1;ADV y4 = x[2] * y2;ADV y5 = x[1] * y2;ADV y6 = x[2] * y1;ADV z1 = y3 + y4;ADV z2 = y6 - y5;cout "x = [" x[0].val ", " x[1].val ", " x[2].val "]" endl;cout "z = [" z1.val ", " z2.val "]" endl;cout "[dz1/dx0, dz2/dx0] = [" z1.dval "," z2.dval "]" endl;}运行结果:

矢量 [2,1] 被旋转 180° , 变为 [-2,-1]。关于角度的导数为 [-1,2]。

自动微分的经典教材是该题目的奠基人 Griewank 著的 Evaluating Derivatives (Society for Industrial and Applied Mathematics)

该书囊括了自动微分的所有方面,比如本文未介绍的 reverse mode, sparse Jacobian, Hessian 等。

如果不求全面,一本更通俗更面向代码实现的书是 The Art of Differentiating Computer Programs (Society for Industrial and Applied Mathematics)

最后,自动微分是算导数的最优方法,比符号计算、有限微分更快更精确。

自动微分已经广泛应用在优化领域,包括人工神经网络的训练算法 back-propagation。

要解连续优化或非线性方程,自动微分是不二的选择。

【前进四先生的回答(13票)】:

请问自动微分法(Automatic differentiation)是如何用C++实现的? - 李瞬生的回答关于自动微分法理论的部分讲的已经很全面了,最后给出的示例代码也很棒,不过是runtime期间求值。我看题目提到了compile time求值,我补充一下这部分。

仅依靠模板元编程,是不能实现在编译期间求值的,因为模板参数只能是整数,对于求导来说不够用。

但是在c++11标准下,我们有了新的编译期间求值工具,constexpr表达式。只要在函数的定义前加上constepxr,编译器就可以在编译期间对函数求值,并将求值结果视为一个编译期间的常量。但是对于constexpr函数,有额外的要求

  • 函数体只能有一句,就是return语句。
  • 这唯一的语句不能使用非常量表达式的函数(标准库的math函数全跪)、全局数据,且必须是一个常量表达式。
constexpr函数的返回值不再限定为整数,可以是浮点数,还可以是简单的类。这意味着constexpr元编程比模板元编程威力要强大的多。

我按照这个思路将 @李瞬生的代码翻译到了constexpr元编程的形式,如下。

#include iostreamnamespace AD {templatetypename Tconstexpr T pow_helper(T x, unsigned n) { return (n == 0) ? 1 : (n % 2 == 0) ? pow_helper(x*x, n/2) : x * pow_helper(x*x, (n-1)/2);}// x^ntemplatetypename Tconstexpr T pow(T x, int n) { return (n == 0) ? 1 : (n 0) ? pow_helper(x, n) : 1 / pow_helper(x, -n);}// n!templatetypename Tconstexpr T factorial(T n) { return (n == 1) ? 1: n*factorial(n-1);}// x^n/n!templatetypename Tconstexpr T xn_n(T x, unsigned n) { return pow(x, n) / factorial(n);}// calculation sin(x), N controls Taylor series ordertemplatetypename Tconstexpr double sin(T x, unsigned N = 12) { return (N % 2 == 0 || N == 1) ? ((N ==1 ) ? x : sin(x, N-1)) : (((N-1)/2) % 2 == 0) ? (xn_n(x, N) + sin(x, N-2)) : (-xn_n(x, N) + sin(x, N-2));}// calculation cos(x), N controls Taylor series ordertemplatetypename Tconstexpr double cos(T x, unsigned N = 12) { return (N % 2 == 1 || N == 0) ? ((N == 0) ? 1 : cos(x, N-1)) : ((N/2) % 2 == 0) ? (xn_n(x, N) + cos(x, N-2)) : (-xn_n(x, N) + cos(x, N-2));}// vartemplatetypename Tstruct Var { constexpr Var(T in): y(in), dy(1.0) {} T y; T dy;};// consttemplatetypename Tstruct Konst { constexpr Konst(T in): y(in), dy(0.0) {} T y; T dy;};// cos//FIXME: 当N大于12时,N!会超出unsigned的最大值,导致精度下降templatetypename Tstruct Cos { constexpr Cos(T in, unsigned N = 12): y(cos(in.y, N)), dy(-sin(in.y, N)*in.dy) {} double y; double dy;};// sin//FIXME: 当N大于12时,N!会超出unsigned的最大值,导致精度下降templatetypename Tstruct Sin { constexpr Sin(T in, unsigned N = 12): y(sin(in.y, N)), dy(cos(in.y, N)*in.dy) {} double y; double dy;};// addtemplatetypename T1, typename T2struct Add { constexpr Add(T1 in1, T2 in2): y(in1.y + in2.y), dy(in1.dy + in2.dy) {} double y; double dy;};// subtemplatetypename T1, typename T2struct Sub { constexpr Sub(T1 in1, T2 in2): y(in1.y - in2.y), dy(in1.dy - in2.dy) {} double y; double dy;};// mutiplytemplatetypename T1, typename T2struct Mul { constexpr Mul(T1 in1, T2 in2): y(in1.y * in2.y), dy(in1.dy*in2.y + in1.y*in2.dy) {} double y; double dy;};} // end namespace ADint main() { using namespace AD; constexpr Vardouble x0 {3.1415926}; constexpr Konstdouble x1 {2.0}; constexpr Konstdouble x2 {1.0}; constexpr Cosdecltype(x0) y1 {x0}; constexpr Sindecltype(x0) y2 {x0}; constexpr Muldecltype(x1), decltype(y1) y3 {x1, y1}; constexpr Muldecltype(x2), decltype(y2) y4 {x2, y2}; constexpr Muldecltype(x1), decltype(y2) y5 {x1, y2}; constexpr Muldecltype(x2), decltype(y1) y6 {x2, y1}; // z1 = x1*cos(x0) + x2*sin(x0) constexpr Adddecltype(y3), decltype(y4) z1 {y3, y4}; // z2 = x2*cos(x0) - x1*sin(x0) constexpr Subdecltype(y6), decltype(y5) z2 {y6, y5}; static_assert(z1.dy 0 , "Got z1 diff in compile time!"); static_assert(z2.dy 0 , "Got z2 diff in compile time!"); std::cout "x = [" x0.y ", " x1.y ", " x2.y "]" std::endl; std::cout "z = [" z1.y ", " z2.y "]" std::endl; std::cout "[dz1/dx0, dz2/dx0] = [" z1.dy "," z2.dy "]" std::endl; return 0;}

将文件存储为constexpr_ad.cpp

g++ -std=c++11 constexpr_ad.cpp -o ad./ad输出如下:

结果是正确的,但是没有 结果是正确的,但是没有 @李瞬生的结果精确。这可以理解,毕竟是自己写的三角函数。

那如何证明这是在编译期间计算出来的呢,用c++11的编译期断言static_assert即可。static_assert要求表达式的结果必须是在编译期间就可以求值,不然就会报错。我们这里没报错,就说明这个值已经计算出来了。

如果不放心,可以将ad的rodata段dump出来看看,看看导数到底有没有存在常量表中,如下:

最后八个字节:631a6803 80000040正好是2.00024的little endian二进制表示。.rodata上面那一堆常量都是中间计算结果,大家有兴趣可以去最后八个字节:631a6803 80000040正好是2.00024的little endian二进制表示。.rodata上面那一堆常量都是中间计算结果,大家有兴趣可以去Floating Point to Hex Converter换算成double类型试试。

ps1: 写完这篇我就想,会不会有人已经实现编译期间的数学计算库,搜了一下,果然有GitHub - Morwenn/static_math: Compile time mathematic functions for C++14.

ps2:在c++14的标准中,constexpr函数的要求比c++11宽松的多。函数体可以使用for循环语句了,看来标准委员会对constexpr metaprogramming也很喜欢啊。

【徐普的回答(3票)】:

上面的回答都是纸上谈兵.

auto diff 用的很广的,eigen有,不过deprecate了。ceres是力推auto diff的,它是用基于dual number的方式实现的.

一个函数以模板的形式写好,把dual number做类型参数带进去就可以得到一阶导. 理论基础是 non standard analysis.

代码实现到ceres里面去找,很简单的

【科学匠人的回答(0票)】:

github上搜啊

codipack

adept

都是简短精悍的表达式模版实现

号称性能和人工代码相当

那也能和代码转换工具相当

不要问我原理,我只是用户

【Kache的回答(0票)】:

ceres里面就有

新闻聚焦
热门推荐
  • 低俗靡乱!喜宴竟充斥惊艳脱衣舞表演

    中新网12月7日电 据台湾《联合报》报道,桃园县内喜宴、庙会、社团、晚会充斥钢管、清凉秀、脱衣舞,县议员舒翠玲以自己参加的场合为证,当场看见辣妹和客人磨蹭,甚至指导单位是“桃园县政府”、“公所”的活动也如......

    01-13 来源:未知

    分享
  • 《我是特种兵之霹雳火》崔华盾扮演者张进个人资料及照

    《我是特种兵之霹雳火》崔华盾扮演者 本篇电视资讯由未必孤独网(www.vbgudu.com)独家整理,如有转载请注明出处。 曾经同是“特警小虎队”一员的李飞和张进这次将重新在《霹雳火》中集结,并且再度并肩作战。 由李......

    01-13 来源:未知

    分享
  • 郎永淳老婆吴萍患肿瘤赴美疗养 郎永淳近况

    郎永淳温馨全家福 央视新闻主播郎永淳虽然在电视上天天与观众见面,因播报新闻成了公众人物,并拥有了很多的粉丝。但生活中的郎永淳却十分很低调,不仅从未谈及过自己的私生活,就连他的另一半及孩子也未被曝光过。......

    01-13 来源:未知

    分享
  • 《我是特种兵之霹雳火》王星扮演者李飞个人资料及照片

    《我是特种兵之霹雳火》王星扮演者李飞 本篇电视资讯由未必孤独网(www.vbgudu.com)独家整理,如有转载请注明出处。 《我是特种兵之霹雳火》作为刘猛导演特种兵系列的第四部作品,自筹划以来就备受网友关注。承继着......

    01-13 来源:未知

    分享
  • 梦鸽:为孩子做什么都不为过 李案会造成世界影响

    梦鸽(资料图) 李某某等涉嫌强奸案从2月份发酵至今,持续半年,热度不减。作为被告李某某的监护人,梦鸽放下红色明星、部队歌唱家的尊严,发布声明、反诉、上访,走进长枪短炮的包围圈,代替独子站在第一线。 为了......

    01-13 来源:未知

    分享
  • 雷!彪悍美女竟在大街上做超不雅动作

    ......

    01-13 来源:未知

    分享
  • 孙俪微博拍卖老公邓超的爱裤,邓超与孙俪感情好不好

    今天我们来盘点一下娱乐圈的模范夫妻。孙俪和邓超是娱乐圈有名的模范夫妻,两人相爱至今都没有穿过其他的绯闻,而在邓超走向逗比之路的过程中,娘娘孙俪也开始受到影响,近日邓超在网上晒了一张与孙俪的另类合影,网......

    01-12 来源:

    分享
  • 巩俐与孙红雷谈过恋爱吗?巩俐孙红雷主演的电影是哪部

    从绯闻女友巩俐、左小青,到王骏迪,孙红雷绯闻伴随走红。在《窈窕绅士》发布会上,孙红雷大晒幸福,并直言,“我现在还不会和女友公开亮相,以免被人说我在炒作。”被问及是否有意结婚,他说,“谈婚论嫁对我来说不......

    01-12 来源:

    分享
  • 曝盛一伦喜欢骂人成瘾,盛一伦同性恋是真的吗?

    子妃升职记不仅火啦张天爱,也让男主盛一伦踏进拉娱乐圈。盛一伦被曝骂人成瘾 骂人聊天记录图片,近日,盛一伦将东家乐漾影视诉至法院,索片酬1051.5万元,朝阳法院已受理此案。12月12日,盛一伦发长文回应解约风波称......

    01-12 来源:

    分享
  • 北京学生卡坐地铁打折吗?北京现在有几条地铁?

    北京的物价使出拉名的贵,许多北漂为啦省钱想尽办法。近日,在北京部分地铁站周边,出现贩卖“”的卡贩子,100元就能办一张大,还送学生证。新京报记者探访发现,从卡贩子手中购得的,能顺利充值并可享受2.5折优惠。......

    01-12 来源:

    分享
返回列表