I would like to get a simple example running in MATLAB that uses a neural network to learn an arbitrary function from input/output data (essentially model identification) and then approximates that function given only the input. To train the network I implemented a simple backpropagation algorithm, but the results are nowhere near satisfactory. I would like to know what I may be doing wrong and what approach I could use instead.
The goal is for the network to represent an identified function f(x) that takes a series x as input and outputs the learned mapping x -> y.
Here is the GNU Octave code I have so far:
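For reference, here is the same model written out as equations. This is a sketch read off the code in this question, not a textbook derivation: it assumes a squared-error loss of 0.5 (ŷ − y)², which the code never states explicitly but which matches the output delta it computes. The symbols a_l and δ_l correspond to the fields nn.a1..nn.a4 and the variables d2..d4, and λ is nn.lambda (used as the learning rate). Note that the network has no bias terms.

```latex
\begin{aligned}
a_1 &= x, \qquad a_2 = \tanh(W_2 a_1), \qquad a_3 = \tanh(W_3 a_2), \qquad \hat{y} = a_4 = W_4 a_3 \\
\delta_4 &= \hat{y} - y \\
\delta_3 &= \left(W_4^{\top} \delta_4\right) \odot \left(1 - a_3^{2}\right) \\
\delta_2 &= \left(W_3^{\top} \delta_3\right) \odot \left(1 - a_2^{2}\right) \\
W_l &\leftarrow W_l - \lambda\, \delta_l\, a_{l-1}^{\top}, \qquad l = 2, 3, 4
\end{aligned}
```

Here ⊙ is the element-wise product and the squares are element-wise, using tanh'(z) = 1 − tanh(z)².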
pkg load control signal
% Activation functions and their derivatives
function r = sigmoid(z)
  r = 1 ./ (1 + exp(-z));
end

function r = linear(z)
  r = z;
end

function r = grad_sigmoid(z)
  r = sigmoid(z) .* (1 - sigmoid(z));
end

function r = grad_linear(z)
  r = 1;
end

function r = grad_tanh(z)
  r = 1 - tanh(z) .^ 2;
end
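% Initialise the weights uniformly in [-1, 1]; there are no bias terms.
function nn = nn_init(n_input, n_hidden1, n_hidden2, n_output)
  nn.W2 = (rand(n_input, n_hidden1) * 2 - 1)';    % n_hidden1 x n_input
  nn.W3 = (rand(n_hidden1, n_hidden2) * 2 - 1)';  % n_hidden2 x n_hidden1
  nn.W4 = (rand(n_hidden2, n_output) * 2 - 1)';   % n_output x n_hidden2
  nn.lambda = 0.005;                              % learning rate
end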
% One gradient-descent step on a single (state, action) sample.
function nn = nn_train(nn_in, state, action)
  nn = nn_in;
  [out, nn] = nn_eval(nn, state);
  % Backpropagate the output error through the layers
  d4 = (nn.a4 - action) .* grad_linear(nn.W4 * nn.a3);
  d3 = (nn.W4' * d4) .* grad_tanh(nn.W3 * nn.a2);
  d2 = (nn.W3' * d3) .* grad_tanh(nn.W2 * nn.a1);
  % Update the weights
  nn.W4 -= nn.lambda * (d4 * nn.a3');
  nn.W3 -= nn.lambda * (d3 * nn.a2');
  nn.W2 -= nn.lambda * (d2 * nn.a1');
end
% Forward pass: two tanh hidden layers followed by a linear output layer.
function [out, nn] = nn_eval(nn_in, state)
  nn = nn_in;
  nn.z1 = state;
  nn.a1 = nn.z1;
  nn.a2 = tanh(nn.W2 * nn.a1);
  nn.a3 = tanh(nn.W3 * nn.a2);
  nn.a4 = linear(nn.W4 * nn.a3);
  out = nn.a4;
end
% Train on samples of sin(t) and plot the target against the network output
nn = nn_init(1, 100, 100, 1);
t = 1:0.1:3.14*10;
input = t;
output = sin(input);
learned = zeros(1, length(output));
for j = 1:500
  for i = 1:length(input)
    nn = nn_train(nn, [input(i)], [output(i)]);
  end
  j  % print the epoch number
end
for i = 1:length(input)
  learned(i) = nn_eval(nn, [input(i)]);
end
plot(t, output, 'g', t, learned, 'b');
pause
Here is the result:
The result is not even close to what I want. Is something wrong with my implementation of backpropagation?
What changes do I need to make to the code to get a better approximation?
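One way to test the backpropagation part in isolation is a numerical gradient check: compare the analytic gradients against central finite differences of the loss. The following is a minimal sketch, not a definitive test. It assumes a squared-error loss 0.5 * (out - y)^2 (not stated in the code above, but consistent with its d4 term), uses a smaller 1-5-5-1 network with the same layer structure, and the helper names forward, loss, backprop_grads and numeric_grads are made up for this example.

```octave
1;  % Octave idiom: a leading statement lets a script file define functions

% Forward pass with the same structure as nn_eval: tanh, tanh, linear, no biases.
function [out, a2, a3] = forward(W, x)
  a2  = tanh(W{1} * x);
  a3  = tanh(W{2} * a2);
  out = W{3} * a3;
end

% Assumed loss: half the squared error on one sample.
function L = loss(W, x, y)
  L = 0.5 * sum((forward(W, x) - y) .^ 2);
end

% Analytic gradients, using the same delta rules as nn_train.
function G = backprop_grads(W, x, y)
  [out, a2, a3] = forward(W, x);
  d4 = out - y;
  d3 = (W{3}' * d4) .* (1 - a3 .^ 2);
  d2 = (W{2}' * d3) .* (1 - a2 .^ 2);
  G = {d2 * x', d3 * a2', d4 * a3'};
end

% Central finite differences, one weight at a time.
function G = numeric_grads(W, x, y, h)
  G = cell(size(W));
  for l = 1:numel(W)
    G{l} = zeros(size(W{l}));
    for k = 1:numel(W{l})
      Wp = W; Wp{l}(k) = Wp{l}(k) + h;
      Wm = W; Wm{l}(k) = Wm{l}(k) - h;
      G{l}(k) = (loss(Wp, x, y) - loss(Wm, x, y)) / (2 * h);
    end
  end
end

% Small random network (weights in [-1, 1]) and a single training sample.
W = {rand(5, 1) * 2 - 1, rand(5, 5) * 2 - 1, rand(1, 5) * 2 - 1};
x = 0.7;
y = sin(x);

Ga = backprop_grads(W, x, y);
Gn = numeric_grads(W, x, y, 1e-5);

max_err = 0;
for l = 1:numel(W)
  max_err = max(max_err, max(abs(Ga{l}(:) - Gn{l}(:))));
end
fprintf('max abs difference between analytic and numeric gradients: %g\n', max_err);
```

If the printed difference is tiny relative to the gradient magnitudes, the delta rules agree with the assumed loss, which would suggest looking at other parts of the setup (for example the input range, the missing bias terms, or the learning rate) rather than at the gradient formulas themselves.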