I'm trying to build a neural network that can answer to the xor problem. My code is the following:
using MXNet
using Distributions
using PyPlot
xor_data = zeros(4,2)
xor_data[1:0] = 1
xor_data[1:1] = 1
xor_data[2:0] = 1
xor_data[2:1] = 0
xor_data[3:0] = 0
xor_data[3:1] = 1
xor_data[4:0] = 0
xor_data[4:1] = 0
xor_labels = zeros(4)
xor_labels[1] = 0
xor_labels[2] = 1
xor_labels[3] = 1
xor_labels[4] = 0
batchsize = 4
trainprovider = mx.ArrayDataProvider(:data => xor_data, batch_size=batchsize, shuffle=true, :label => xor_labels)
evalprovider = mx.ArrayDataProvider(:data => xor_data, batch_size=batchsize, shuffle=true, :label => xor_labels)
data = mx.Variable(:data)
label = mx.Variable(:label)
net = @mx.chain mx.Variable(:data) =>
mx.FullyConnected(num_hidden=2) =>
mx.Activation(act_type=:relu) =>
mx.FullyConnected(num_hidden=2) =>
mx.Activation(act_type=:relu) =>
mx.FullyConnected(num_hidden=1) =>
mx.Activation(act_type=:relu) =>
model = mx.FeedForward(net, context=mx.cpu())
optimizer = mx.SGD(lr=0.01, momentum=0.9, weight_decay=0.00001)
initializer = mx.NormalInitializer(0.0,0.1)
eval_metric = mx.MSE()
mx.fit(model, optimizer, initializer, eval_metric, trainprovider, eval_data = evalprovider, n_epoch = 100)
mx.fit(model, optimizer, eval_metric, trainprovider, eval_data = evalprovider, n_epoch = 100)
But I'm getting the following error:
LoadError: AssertionError: Number of samples in label is mismatch with data in expression starting on line 22 in #ArrayDataProvider#6428(::Int64, ::Bool, ::Int64, ::Int64, ::Type{T}, ::Pair{Symbol,Array{Float64,2}}, ::Pair{Symbol,Array{Float64,1}}) at io.jl:324 in (::Core.#kw#Type)(::Array{Any,1}, ::Type{MXNet.mx.ArrayDataProvider}, ::Pair{Symbol,Array{Float64,2}}, ::Pair{Symbol,Array{Float64,1}}) at :0 in include_string(::String, ::String) at loading.jl:441 in include_string(::String, ::String) at sys.dylib:? in include_string(::Module, ::String, ::String) at eval.jl:32 in (::Atom.##59#62{String,String})() at eval.jl:81 in withpath(::Atom.##59#62{String,String}, ::String) at utils.jl:30 in withpath(::Function, ::String) at eval.jl:46 in macro expansion at eval.jl:79 [inlined] in (::Atom.##58#61{Dict{String,Any}})() at task.jl:60
I want to feed to the network to values (0 or 1) and get a single value. Were is my error?