NeuralNetDiffEq.jl:用于常微分方程的神经网络求解器

2017 年 10 月 13 日 | Akshay Sharma

我的 GSoC 2017 项目 是为 Julia 实现一个包,使用神经网络求解常微分方程。该项目的目的是提供一个使用神经网络的附加 DE 求解器,其关键优势在于时间上的并行性,优于其他本质上是迭代的求解器。该项目基于 Lagaris 等人 1997 年 的研究论文,该论文提出了神经网络 (NN) 的函数逼近能力,用于求解微分方程。该项目融合了研究和实施方面,但仍有一些部分有待完善。我选择参与这个项目,因为我对数学和机器学习感兴趣,并且它涉及这两个领域的理念。该包使用 DifferentialEquations.jl 作为求解器接口,使用 Knet.jl 用于 NN 求解器实现。

如何使用神经网络求解微分方程?

该求解器的概念基于 UAT(通用逼近定理),该定理指出,具有至少一个隐藏层的 NN 可以逼近任何连续函数。神经网络被设计为最小化损失函数,该函数定义为 NN 的导数和微分方程的导数之间的差异,这将导致我们的试解收敛到微分方程的实际(解析)解。要了解更多关于 UAT 的信息,请点击这里

项目的科研方面和挑战

我们参考的关于该主题的研究论文相当古老,理解示例和解释非常具有挑战性。使用 NN 用于此目的的研究并不多,因此我们无法从与该主题相关的研究论文中获得太多帮助。最初的任务是阅读和理解求解微分方程背后的数学原理。此外,在计算机上求解微分方程所使用的计算方法与我们在纸上使用的计算方法有很大不同,因此需要花费相当长的时间才能熟悉它们。要使用的 NN 的结构和类型,以便在不影响性能的情况下保留求解器的优势(时间上的并行性),这本身就是一个研究子领域,也是一个挑战。

在实现常微分方程 (ODE) 和 ODE 系统的求解器之后,困难的部分是使 NN 在更长时间域上针对 ODE 系统收敛。由于神经网络涉及许多因素,例如隐藏层宽度、隐藏神经元数量、激活函数、权重等,我依靠自己的机器学习背景以及导师的帮助,对 NN 超参数的大多数可行设置进行了实验,并记录了收敛精度和求解器的性能。使 NN 针对 ODE 系统收敛并不像看起来那样容易,这占用了大部分的实验和调优时间。预测具有较大域的 DE 系统解仍然是一个挑战,需要进一步研究。

实现和工作

实现涉及整合数学和机器学习方面,以构建用于 ODE 的神经网络求解器。 DiffEqBase 库 用作基础,用于扩展算法和求解器接口,而神经网络则是使用 Knet.jl 库 开发的。迄今为止完成的工作可以在 NeuralNetDiffEq.jl github 库 中查看,主要是在 此分支 中。这项工作涉及实现用于 ODE 的神经网络求解器,并根据 NN 预测进行定制插值。

它是如何工作的?

我们根据 NN 输出构建微分方程的试解,该试解也应该满足 DE 边界条件。我们为神经网络定义一个损失函数,该函数是神经网络解关于其输入的导数与 ODE 定义的真实导数之间的差异。这是一个不寻常的损失函数,因为它包含网络本身的梯度。在其他 ML 应用程序中,这种情况几乎从未见过。使用 NN 对该损失函数进行最小化(通过将导数差异等为零),将试解代入原始函数(或 DE 的解)中。神经网络使用 Adam 优化算法根据该损失函数反向传播的梯度来调整其权重。

为了实现时间上的并行化,我们使用 KnetArray(Knet.jl 中使用的数组类型),该类型默认使用 CPU,但也支持 GPU 使用以实现并行化,但需要安装 CUDNN 驱动程序才能访问 GPU 硬件。

示例

下面您可以找到一些关于如何使用我一直在开发的软件包的示例。以下是该软件包正常工作所需的初始导入。

using NeuralNetDiffEq
using Plots; plotly()
using DiffEqBase, ParameterizedFunctions
using DiffEqProblemLibrary, DiffEqDevTools
using Knet

ODE 示例

示例 1

linear = (t,u) -> (1.01*u)
(f::typeof(linear))(::Type{Val{:analytic}},t,u0) = u0*exp(1.01*t)
prob = ODEProblem(linear,1/2,(0.0,1.0))
sol = solve(prob,nnode(10),dt=1/10,iterations=10)
plot(sol,plot_analytic=true)

Plot_ode1

sol(0.232)

1-element Array{Any,1}:
0.625818

示例 2

f = (t,u) -> (t^3 + 2*t + (t^2)*((1+3*(t^2))/(1+t+(t^3))) - u*(t + ((1+3*(t^2))/(1+t+t^3))))
(::typeof(f))(::Type{Val{:analytic}},t,u0) =  u0*exp(-(t^2)/2)/(1+t+t^3) + t^2
prob2 = ODEProblem(f,1.0,(0.0,1.0))
sol2 = solve(prob2,nnode(10),dt=0.1,iterations=200)
 plot(sol,plot_analytic=true)

Plot_ode2

sol(0.47)

1-element Array{Any,1}:
0.803109

示例 3

f2 = (t,u) -> (-u/5 + exp(-t/5).*cos(t))
(::typeof(f2))(::Type{Val{:analytic}},t,u0) =  exp(-t/5)*(u0 + sin(t))
prob3 = ODEProblem(f2,Float32(0.0),(Float32(0.0),Float32(2.0)))
sol3 = solve(prob3,nnode(10),dt=0.2,iterations=1000)
 plot(sol,plot_analytic=true)

Plot_ode3

sol3([0.721])

  1-element Array{Any,1}:
  Any[0.574705]

ODE 系统示例

示例 1 ODE 2D 线性

f_2dlinear = (t,u) -> begin
    du = Array{Any}(length(u))
    for i in 1:length(u)
    du[i] = 1.01*u[i]
  end
    return du
end
(f::typeof(f_2dlinear))(::Type{Val{:analytic}},t,u0) = u0*exp.(1.01*t)
prob_ode_2Dlinear = ODEProblem(f_2dlinear,rand(4,1),(0.0,1.0))
sol1 = solve(prob_ode_2Dlinear,nnode([10,50]),dt=0.1,iterations=100)

(:iteration,100,:loss,0.004670103680503722)
16.494870 seconds (90.08 M allocations: 6.204 GB, 5.82% gc time)
plot(sol1,plot_analytic=true)

Plot_sode1

示例 2 Lotka Volterra

function lotka_volterra(t,u)
  du1 = 1.5 .* u[1] - 1.0 .* u[1].*u[2]
  du2 = -3 .* u[2] + u[1].*u[2]
  [du1,du2]
end

lotka_volterra (generic function with 1 method)
prob_ode_lotkavoltera = ODEProblem(lotka_volterra,Float32[1.0,1.0],(Float32(0.0),Float32(1.0)))
sol2 = solve(prob_ode_lotkavoltera,nnode([10,50]),dt=0.2,iterations=1000)

    (:iteration,100,:loss,0.020173132003438572)
    (:iteration,200,:loss,0.005130137452114811)
    (:iteration,300,:loss,0.004812458584875589)
    (:iteration,400,:loss,0.010083624565714974)
    (:iteration,500,:loss,0.0025328170079611887)
    (:iteration,600,:loss,0.007685579218433846)
    (:iteration,700,:loss,0.005065291031504465)
    (:iteration,800,:loss,0.005326863832044214)
    (:iteration,900,:loss,0.00030436474139241827)
    (:iteration,1000,:loss,0.0034853904995959094)
     22.126081 seconds (99.65 M allocations: 5.923 GB, 5.21% gc time)
plot(sol2)

Plot_sode2

为了说明求解器在当前设置和超参数下不适用于较大域(例如 0-10)lotka_volterra,这里有一张图表

prob_ode_lotkavoltera = ODEProblem(lotka_volterra,Float32[1.0,1.0],(Float32(0.0),Float32(5.0)))
sol3 = solve(prob_ode_lotkavoltera,nnode([10,50]),dt=0.2,iterations=1000)
plot(sol3)

Plot_sode3

但是,真实解应该是振荡的,表明 NN 未正确收敛。要查看更多示例和实验结果,您可以查看我的 Jupyter 笔记本 这里

未来工作

需要更多关于如何优化 NN 以提高速度和收敛性的研究。对于具有较大域的 ODE 系统,当前的神经网络无法收敛。可以使用一种优化算法对 NN 超参数进行一次性优化,以便它能够更好地处理 ODE 系统。我们尝试了许多方法,例如对成本函数进行偏差以优先考虑较早的时间点,但这同样失败了。在 使用 TensorFlow 的替代实现 (TensorFlowDiffEq.jl) 中发现了类似的问题,这表明这可能仅仅是求解方法的问题。

鸣谢

我非常感谢我的 GSoC 导师 Chris Rackauckas 和 Lyndon White 在理解项目的数学和编码部分方面提供的帮助。此外,我要感谢 Julia 社区,感谢他们给我提供贡献的机会,并赞助了我的 JuliaCon 2017 之旅,那次旅行很棒。