nn_hebbian.h

Go to the documentation of this file.
00001 #include "object.h"
00002 #include "mimasexception.h"
00003 #include "matrix_tool.h"
00004 #include "values.h"
00005 
00006 #ifndef NN_HEBBIAN_H
00007 #define NN_HEBBIAN_H
00008 
00009 namespace mimas {
00016   class nn_hebbian: public object
00017   {
00018 
00019     private:
00020       matrix_tool weights; 
00021       int trainingType; 
00022       double eta; 
00023 
00024     public:
00025       nn_hebbian(void); 
00026       ~nn_hebbian(void); 
00027       void init(int rows, int cols); 
00028       void setTrainingType(int type); 
00029       void setLearningRate(double lr); 
00030          void update(matrix_tool& pat, double targetval) throw (exception); 
00031          double recall(matrix_tool& pat) throw (exception); 
00032       void useWeights(matrix_tool& weight_matrix); 
00033       matrix_tool getWeights(void); 
00034   };
00035 
00036 
00037   nn_hebbian::nn_hebbian()
00038   {
00039     trainingType=0; // default training type
00040     eta=0.01;
00041   }
00042 
00043   void nn_hebbian::init(int rows, int cols)
00044   {
00045     if (weights.initialised()) weights.clear();
00046     weights.init(rows, cols);
00047   }
00048 
00049   void nn_hebbian::setTrainingType(int type)
00050   {
00051     trainingType=type;
00052   }
00053 
00054   void nn_hebbian::setLearningRate(double lr)
00055   {
00056     eta=lr;
00057   }
00058 
00059   // update the Hebbian weights given pattern pat
00060   void nn_hebbian::update(matrix_tool& pat, double targetval)
00061      throw (exception)
00062   {
00063      MMERROR( weights.getRows() ==pat.getRows() &&
00064               weights.getCols() == pat.getCols(),
00065               exception, , "nn_hebbian::update - weight and input "
00066               "matrices dimensions not equal" );
00067 
00068     for (int i=0; i<weights.getRows(); i++)
00069       for (int j=0; j<weights.getCols(); j++)
00070       {
00071         switch (trainingType)
00072         {
00073           // *post-not-pre* long term depression (LTD)
00074           case 1: weights(i,j) = weights(i,j) + eta * (2.0*pat(i,j)-1) * targetval;
00075               break;
00076 
00077               // *pre-not-post* long term depression (LTD)
00078           case 2: weights(i,j) = weights(i,j) + eta * pat(i,j) * (2 * targetval -1);
00079               break;
00080 
00081               // Long term potention (LTP)
00082           default: weights(i,j) = weights(i,j) + eta * pat(i,j) *targetval;
00083         }
00084 
00085       }
00086 
00087   }
00088 
00089 
00090   double nn_hebbian::recall(matrix_tool& pat) throw (exception)
00091   {
00092      MMERROR( weights.getRows() == pat.getRows() &&
00093               weights.getCols() == pat.getCols(), exception, ,
00094               "nn_hebbian::recall - weight and input matrices dimensions "
00095               "not equal" );
00096 
00097     double targetval=0.0;
00098 
00099     for (int i=0; i<weights.getRows(); i++)
00100       for (int j=0; j<weights.getCols(); j++)
00101         targetval=targetval + weights(i,j) * pat(i,j);
00102 
00103     return targetval;
00104 
00105   }
00106 
00107 
00108   // use the weights given by weight_matrix;
00109   void nn_hebbian::useWeights(matrix_tool& weight_matrix)
00110   {
00111     if (weights.initialised()) weights.clear();
00112     weights=weight_matrix;
00113   }
00114 
00115 
00116   // return the weights that are stored 
00117   matrix_tool nn_hebbian::getWeights(void)
00118   {
00119     matrix_tool temp=weights;
00120 
00121     return weights;
00122   }
00123 }
00124 
00125 #endif

[GNU/Linux] [Qt] [Mesa] [STL] [Lapack] [Boost] [Magick++] [Xalan-C and Xerces-C] [doxygen] [graphviz] [FFTW] [popt] [xine] [Gnuplot] [gnu-arch] [gcc] [gstreamer] [autoconf/automake/make] [freshmeat.net] [opensource.org] [sourceforge.net] [MMVL]
mimas 2.1 - Copyright Mon Oct 30 11:31:17 2006, Bala Amavasai, Stuart Meikle, Arul Selvan, Fabio Caparrelli, Jan Wedekind, Manuel Boissenin, ...