Neural network: how to train a composition of models in Flux?


I am trying to build a deep learning model in Julia. I have two models, m1 and m2, which are neural networks. Here is my code:

using Flux

function even_mask(x)
    s1, s2 = size(x)
    weight_mask = zeros(s1, s2)
    weight_mask[2:2:s1,:] = ones(Int(s1/2), s2)
    return weight_mask
end

function odd_mask(x)
    s1, s2 = size(x)
    weight_mask = zeros(s1, s2)
    weight_mask[1:2:s1,:] = ones(Int(s1/2), s2)
    return weight_mask
end

function even_duplicate(x)
    s1, s2 = size(x)
    x_ = zeros(s1, s2)
    x_[1:2:s1,:] = x[1:2:s1,:]
    x_[2:2:s1,:] = x[1:2:s1,:]
    return x_
end

function odd_duplicate(x)
    s1, s2 = size(x)
    x_ = zeros(s1, s2)
    x_[1:2:s1,:] = x[2:2:s1,:]
    x_[2:2:s1,:] = x[2:2:s1,:]
    return x_
end

function Even(m)
    x -> x .+ even_mask(x).*m(even_duplicate(x))
end

function InvEven(m)
    x -> x .- even_mask(x).*m(even_duplicate(x))
end

function Odd(m)
    x -> x .+ odd_mask(x).*m(odd_duplicate(x))
end

function InvOdd(m)
    x -> x .- odd_mask(x).*m(odd_duplicate(x))
end

m1 = Chain(Dense(4,6,relu), Dense(6,5,relu), Dense(5,4))
m2 = Chain(Dense(4,7,relu), Dense(7,4))

forward = Chain(Even(m1), Odd(m2))
inverse = Chain(InvOdd(m2), InvEven(m1))

function loss(x)
    z = forward(x)
    return 0.5*sum(z.*z)
end

opt = Flux.ADAM()

x = rand(4,100)

for i=1:100
    Flux.train!(loss, Flux.params(forward), x, opt)
    println(loss(x))
end

The forward model is a composition of m1 and m2. I need to optimize m1 and m2 so that I can optimize both the forward and inverse models. But params(forward) seems to be empty. How can I train my models?

I think plain functions cannot be used as layers in Flux. You need to use the @functor macro to add the extra functionality that collects parameters.

In your case, rewriting Even, InvEven, Odd and InvOdd like this should help:

struct Even
    model
end

(e::Even)(x) = x .+ even_mask(x).*e.model(even_duplicate(x))

Flux.@functor Even
With this definition added,

Flux.params(Even(m1))

should return a non-empty parameter list.
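The difference can be seen in a minimal, self-contained sketch (assuming a Flux v0.10-era API; Wrapper is an illustrative name, not part of the original code): a plain closure hides the inner model's parameters from Flux.params, while a @functor-annotated struct exposes them.

```julia
using Flux

inner = Dense(3, 3)

# Closure wrapper: Flux.params cannot see inside an anonymous function,
# so it finds no trainable arrays.
closure_layer = x -> x .+ inner(x)

# Struct wrapper: the inner model is a field, and @functor tells Flux
# to recurse into it when collecting parameters.
struct Wrapper
    model
end
(w::Wrapper)(x) = x .+ w.model(x)
Flux.@functor Wrapper

struct_layer = Wrapper(inner)

println(length(Flux.params(closure_layer)))  # 0 parameter arrays
println(length(Flux.params(struct_layer)))   # 2 (weight and bias of Dense)
```

This is exactly why params(forward) is empty in the question: Even(m1) there returns an anonymous closure, not a struct that Flux can inspect.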

Edit

A simpler way to implement Even and friends is to use the built-in SkipConnection layer:

I suspect this is a version difference, but with Julia 1.4.1 and Flux v0.10.4 I get the error

BoundsError: attempt to access () at index [1]

when running the training loop. I needed to replace the data with

x = [(rand(4,100), 0)]

Otherwise, the loss is applied to each entry of the array x, since train! splats loss over x.
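To see why the data needs to be a collection of tuples: train! iterates over the data and splats each element into the loss as loss(d...). A simplified sketch of that behavior (loss_fn and data are illustrative names, not Flux API):

```julia
# Simplified model of how train! consumes its data argument:
# each element d of the iterable is splatted into the loss, loss(d...).
loss_fn(x, y) = sum(x) + y

data = [([1.0, 2.0, 3.0], 0)]    # one (input, target) tuple

for d in data
    println(loss_fn(d...))       # equivalent to loss_fn([1.0, 2.0, 3.0], 0)
end
```

With a bare matrix such as rand(4,100), train! would instead iterate over the individual entries and call the loss on each scalar, which produces the BoundsError above.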

The next error,

Mutating arrays is not supported

is caused by the execution of the *_mask and *_duplicate functions. These functions construct an array of zeros and then mutate it by writing in the input values.

You can use Zygote's Buffer to implement this code in a differentiable way:

using Flux
using Zygote: Buffer

function even_mask(x)
    s1, s2 = size(x)
    weight_mask = Buffer(x)
    weight_mask[2:2:s1,:] = ones(Int(s1/2), s2)
    weight_mask[1:2:s1,:] = zeros(Int(s1/2), s2)
    return copy(weight_mask)
end

function odd_mask(x)
    s1, s2 = size(x)
    weight_mask = Buffer(x)
    weight_mask[2:2:s1,:] = zeros(Int(s1/2), s2)
    weight_mask[1:2:s1,:] = ones(Int(s1/2), s2)
    return copy(weight_mask)
end

function even_duplicate(x)
    s1, s2 = size(x)
    x_ = Buffer(x)
    x_[1:2:s1,:] = x[1:2:s1,:]
    x_[2:2:s1,:] = x[1:2:s1,:]
    return copy(x_)
end

function odd_duplicate(x)
    s1, s2 = size(x)
    x_ = Buffer(x)
    x_[1:2:s1,:] = x[2:2:s1,:]
    x_[2:2:s1,:] = x[2:2:s1,:]
    return copy(x_)
end

Even(m) = SkipConnection(Chain(even_duplicate, m),
                         (mx, x) -> x .+ even_mask(x) .* mx)

InvEven(m) = SkipConnection(Chain(even_duplicate, m),
                            (mx, x) -> x .- even_mask(x) .* mx)

Odd(m) = SkipConnection(Chain(odd_duplicate, m),
                        (mx, x) -> x .+ odd_mask(x) .* mx)

InvOdd(m) = SkipConnection(Chain(odd_duplicate, m),
                           (mx, x) -> x .- odd_mask(x) .* mx)

m1 = Chain(Dense(4,6,relu), Dense(6,5,relu), Dense(5,4))
m2 = Chain(Dense(4,7,relu), Dense(7,4))

forward = Chain(Even(m1), Odd(m2))
inverse = Chain(InvOdd(m2), InvEven(m1))

function loss(x, y)
    z = forward(x)
    return 0.5*sum(z.*z)
end

opt = Flux.ADAM(1e-6)

x = [(rand(4,100), 0)]

function train!()
    for i=1:100
        Flux.train!(loss, Flux.params(forward), x, opt)
        println(loss(x[1]...))
    end
end

At this point you get to enjoy the real fun of deep networks: after one training step, training diverges to NaN with the default learning rate. Reducing the initial learning rate to 1e-6 helps, and the loss looks like it is decreasing.
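One way to catch such divergence early is to check the loss for finiteness after each epoch. A hypothetical helper (not part of the original answer; step! would wrap the Flux.train! call and current_loss the evaluation loss(x[1]...)):

```julia
# Hypothetical guard: run a training step each epoch and stop as soon as
# the reported loss is no longer finite (NaN or Inf).
function train_guarded!(step!, current_loss; epochs=100)
    for i in 1:epochs
        step!()
        l = current_loss()
        if !isfinite(l)
            println("loss diverged at epoch $i: $l")
            return i           # report the epoch where divergence was detected
        end
        println(l)
    end
    return epochs
end
```

Stopping early like this avoids a hundred epochs of NaN output and makes it obvious when a learning-rate change (such as the drop to 1e-6 above) is needed.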

Now I am getting this error: Mutating arrays is not supported.