00001 #include "object.h"
00002 #include "mimasexception.h"
00003 #include "matrix_tool.h"
00004 #include "values.h"
00005
00006 #ifndef NN_HEBBIAN_H
00007 #define NN_HEBBIAN_H
00008
00009 namespace mimas {
00016 class nn_hebbian: public object
00017 {
00018
00019 private:
00020 matrix_tool weights;
00021 int trainingType;
00022 double eta;
00023
00024 public:
00025 nn_hebbian(void);
00026 ~nn_hebbian(void);
00027 void init(int rows, int cols);
00028 void setTrainingType(int type);
00029 void setLearningRate(double lr);
00030 void update(matrix_tool& pat, double targetval) throw (exception);
00031 double recall(matrix_tool& pat) throw (exception);
00032 void useWeights(matrix_tool& weight_matrix);
00033 matrix_tool getWeights(void);
00034 };
00035
00036
00037 nn_hebbian::nn_hebbian()
00038 {
00039 trainingType=0;
00040 eta=0.01;
00041 }
00042
00043 void nn_hebbian::init(int rows, int cols)
00044 {
00045 if (weights.initialised()) weights.clear();
00046 weights.init(rows, cols);
00047 }
00048
00049 void nn_hebbian::setTrainingType(int type)
00050 {
00051 trainingType=type;
00052 }
00053
00054 void nn_hebbian::setLearningRate(double lr)
00055 {
00056 eta=lr;
00057 }
00058
00059
00060 void nn_hebbian::update(matrix_tool& pat, double targetval)
00061 throw (exception)
00062 {
00063 MMERROR( weights.getRows() ==pat.getRows() &&
00064 weights.getCols() == pat.getCols(),
00065 exception, , "nn_hebbian::update - weight and input "
00066 "matrices dimensions not equal" );
00067
00068 for (int i=0; i<weights.getRows(); i++)
00069 for (int j=0; j<weights.getCols(); j++)
00070 {
00071 switch (trainingType)
00072 {
00073
00074 case 1: weights(i,j) = weights(i,j) + eta * (2.0*pat(i,j)-1) * targetval;
00075 break;
00076
00077
00078 case 2: weights(i,j) = weights(i,j) + eta * pat(i,j) * (2 * targetval -1);
00079 break;
00080
00081
00082 default: weights(i,j) = weights(i,j) + eta * pat(i,j) *targetval;
00083 }
00084
00085 }
00086
00087 }
00088
00089
00090 double nn_hebbian::recall(matrix_tool& pat) throw (exception)
00091 {
00092 MMERROR( weights.getRows() == pat.getRows() &&
00093 weights.getCols() == pat.getCols(), exception, ,
00094 "nn_hebbian::recall - weight and input matrices dimensions "
00095 "not equal" );
00096
00097 double targetval=0.0;
00098
00099 for (int i=0; i<weights.getRows(); i++)
00100 for (int j=0; j<weights.getCols(); j++)
00101 targetval=targetval + weights(i,j) * pat(i,j);
00102
00103 return targetval;
00104
00105 }
00106
00107
00108
00109 void nn_hebbian::useWeights(matrix_tool& weight_matrix)
00110 {
00111 if (weights.initialised()) weights.clear();
00112 weights=weight_matrix;
00113 }
00114
00115
00116
00117 matrix_tool nn_hebbian::getWeights(void)
00118 {
00119 matrix_tool temp=weights;
00120
00121 return weights;
00122 }
00123 }
00124
00125 #endif