Neural Network Class
All of the code in this project was written by me... with perhaps a few small pieces from Stack Overflow. This project is an attempt at machine learning in Java.
The Goals
I wanted to create a class from scratch that would simulate a neural network, in order to understand on a deeper level how back propagation and other training algorithms function. The project goals were as follows:
- Relatively custom dimensions
- Methods to use the network
- Methods to train the network
- Ability to save the network
- Support for back propagation and "genetic" training algorithms
How It Works
The class is simple to use. When constructing a new NeuralNetwork you specify the number of neurons in the input layer, the number of hidden layers, the number of neurons in each hidden layer, and the number of neurons in the output layer.
For example, the following code will create a NeuralNetwork with 2 input neurons, 1 hidden layer of 3 neurons, and 1 output neuron.
NeuralNetwork net = new NeuralNetwork(2,1,3,1);
Note: the dimensions of the hidden layers can also be specified with an int array for more customization.
After you have your NeuralNetwork instantiated, you can specify the inputs and get an output with the feed() method.
double[] inputs = {1.0,1.0};
double[] outputs = net.feed(inputs);
//Only one output neuron
System.out.printf("Output: %.4f\n", outputs[0]);
The output of this can look something like:
Output: 0.3542
This is great, but unless we want an esoteric random number generator, we need to train the network. For brevity, I will only go over using the built-in back-propagation.
In order to train, we first need an array that holds the target values of the output neurons.
double[] targets = {1.0};
We then pass that array into the calcCost() method as shown:
net.calcCost(targets);
Nothing has actually happened yet; we need to tell the network to adjust its weights according to a learning rate with the learn() method.
net.learn(0.2);
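To build some intuition for what calcCost() and learn() are doing, here is a minimal, self-contained sketch of the same idea: gradient descent on a single sigmoid neuron with a squared-error cost. This is illustrative only, not the project's actual code.

```java
// Illustration only (not the project's code): one sigmoid neuron trained by
// gradient descent, the idea behind calcCost() followed by learn(rate).
public class SingleNeuronDemo {
    public static void main(String[] args) {
        double w = 0.5, b = 0.0;          // weight and bias
        double input = 1.0, target = 1.0; // a single training example
        double rate = 0.2;                // learning rate, as in learn(0.2)

        for (int i = 0; i < 1000; i++) {
            // Forward pass: weighted sum through a sigmoid (like feed())
            double z = w * input + b;
            double out = 1.0 / (1.0 + Math.exp(-z));

            // Gradient of the squared-error cost w.r.t. z (like calcCost())
            double grad = (out - target) * out * (1.0 - out);

            // Step the weights against the gradient (like learn())
            w -= rate * grad * input;
            b -= rate * grad;
        }

        double out = 1.0 / (1.0 + Math.exp(-(w * input + b)));
        System.out.printf("Output after training: %.4f%n", out); // approaches the 1.0 target
    }
}
```

Repeating this forward-pass/gradient/update cycle is exactly what the training loop in the next section does, just across every weight in the network at once.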
Test Code
The learning process must be repeated many times for the network to learn which outputs to yield for given inputs. To show this, here is a full program that teaches the network how to perform the XOR operation.
package NeuralNetwork;

/**
 * Code which tests the <code>NeuralNetwork</code>.
 * It teaches the <code>NeuralNetwork</code> how to
 * perform the XOR operation.
 * @author Evan Partidas
 */
public class NeuralNetworkDriver {

    public static void main(String[] args) {
        NeuralNetwork net = new NeuralNetwork(2, 1, 3, 1);

        double[][] inputs = {
            {1.0, 1.0},
            {0.0, 1.0},
            {1.0, 0.0},
            {0.0, 0.0}
        };

        double[][] targets = {
            {0},
            {1},
            {1},
            {0}
        };

        for (int i = 0; i <= 200000; i++) {
            for (int j = 0; j < 4; j++) {
                net.feed(inputs[j]);
                net.calcCost(targets[j]);
                net.learn(0.2);
            }
            if (i % 8000 == 0) {
                System.out.printf("Error: %.12f %n", net.getError());
            }
            // Reading the error also clears it for the next iteration
            net.getError();
        }

        // Print each input pair, the rounded output, and the raw output
        for (int j = 0; j < 4; j++) {
            double[] result = net.feed(inputs[j]);
            for (double d : inputs[j]) {
                System.out.print(d + " ");
            }
            System.out.print("=");
            for (double d : result) {
                System.out.printf("%.0f  %.5f %n", d, d);
            }
        }
    }
}
The Results
Here is the output of the code from the previous section:
Error: 1.014875502988
Error: 1.004292824064
Error: 1.002261767357
Error: 0.997474268213
Error: 0.985190951483
Error: 0.895985322637
Error: 0.473909484511
Error: 0.318087756888
Error: 0.285337725086
Error: 0.272867228348
Error: 0.266365818334
Error: 0.262274875550
Error: 0.259246936556
Error: 0.256376909589
Error: 0.250385037883
Error: 0.039134013911
Error: 0.008678194293
Error: 0.005036784148
Error: 0.003684144740
Error: 0.002964742038
Error: 0.002508633338
Error: 0.002188276473
Error: 0.001947972870
Error: 0.001759400334
Error: 0.001606529728
Error: 0.001479538593
1.0 1.0 =0 0.01168
0.0 1.0 =1 0.97895
1.0 0.0 =1 0.97828
0.0 0.0 =0 0.02065
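The first number after each = sign is just the raw activation rounded to the nearest integer by the %.0f format. The same thresholding can be done explicitly; the values below are copied from the run above:

```java
public class ThresholdDemo {
    public static void main(String[] args) {
        // Raw network outputs copied from the run above
        double[] raw = {0.01168, 0.97895, 0.97828, 0.02065};
        for (double d : raw) {
            int bit = d >= 0.5 ? 1 : 0; // threshold at 0.5
            System.out.println(d + " -> " + bit);
        }
        // The four bits come out as 0, 1, 1, 0 - the XOR truth table
    }
}
```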
As can be seen, the network indeed learned how to perform XOR: its outputs get closer and closer to the expected values, which also shows up as the error steadily decreasing.
The project was a huge success. After many days of designing, learning, and coding, I now have my own custom-made NeuralNetwork class.
The Code
The source code for the entire project can be found here
The code is divided into 4 classes:

- NeuralNetwork - Main class that houses the data and methods
- Neuron - Necessary class that helps with forward and back propagation
- NeuralCache - Utility class which stores all the matrix data within a Neural Network
- NeuralModel - Utility class which stores a possible structure of the Neural Network

Note: the fourth class, NeuralModel, is not necessary for the stated goals. However, I found the class to be handy.
You can learn a lot more by browsing the source code at the link given above.