nn_mlp_tool.h

Go to the documentation of this file.
00001 //
00002 // Multi-layered perceptron tool class
00003 //
00004 
00005 
00006 #ifndef NN_MLP_TOOL_H
00007 #define NN_MLP_TOOL_H
00008 
#include <stdarg.h>
#include <cmath>
#include <boost/numeric/ublas/matrix.hpp>
#include <boost/numeric/ublas/vector.hpp>
#include <boost/shared_array.hpp>  // boost::shared_array is used for nn_mlp_tool's data members
#include <boost/shared_ptr.hpp>
#include "object.h"
#include "mimasexception.h"
00016 
00017 namespace mimas {
00018 
00019 
00028 class nn_mlp_tool : public object
00029 {
00030    public:
00031       typedef boost::numeric::ublas::matrix< double > Matrix;
00032       typedef boost::numeric::ublas::vector< double > Vector;
00033   private:
00034     typedef double (nn_mlp_tool::*ptrToFuncType)(double); 
00035       boost::shared_array< ptrToFuncType > ptrToTransferFunc;
00036     boost::shared_array< ptrToFuncType > ptrToTransferFuncDeriv;
00037       boost::shared_array< Matrix > weight; 
00038     boost::shared_array< Vector > bias; 
00039     boost::shared_array< Vector > y; 
00040     boost::shared_array< Vector > net; 
00041     boost::shared_array< Vector > delta; 
00042     boost::shared_array< Matrix > prevWeightChange; 
00043     boost::shared_array< Vector > prevBiasChange; 
00044       boost::shared_array< int > numnodes; 
00045     bool network_initialised; 
00046     bool debugtrain; 
00047     int numlayers; 
00048     //double learningRate, momentum; ///< learning rate and momentum
00049     int iterations; 
00050     void init(void); 
00051     double deltaRule( Matrix& pat, int p) throw (mimasexception); // perform delta rule on row p of pattern matrix pat
00052 
00053   public:
00054     nn_mlp_tool(void); 
00055     ~nn_mlp_tool(void); 
00056     void init(int nodes, ...) throw (mimasexception); 
00057     void clear(void); 
00058     void showConfig(void); 
00059       Matrix feedForward( Matrix& pat) throw (mimasexception); 
00060 
00061     // io functions
00062     void save (const char *fn) throw (mimasexception); 
00063     void load (const char *fn) throw (mimasexception); 
00064 
00065 
00066     // accessor functions
00067     void setIterations(int i) throw (mimasexception); 
00068     int getIterations(void); 
00069     
00070     // training methods
00071       Matrix trainOnline( Matrix& pat, double learningRate=0.3, double momentum=0.05, double weightDecay=0.0) throw (mimasexception); 
00072       Matrix trainBatch(Matrix& pat, double learningRate=0.3, double momentum=0.05, double weightDecay=0.0) throw (mimasexception); 
00073     Matrix trainQprop(Matrix& pat, double learningRate=0.0005, double momentum=0.05, double maxFactor=2.0) throw (mimasexception); 
00074       Matrix trainRprop(Matrix& pat, double etapos=1.2, double etaneg=0.5) throw (mimasexception); 
00075 
00076     // set transfer function for a layers
00077     void setTransferFuncLayer(ptrToFuncType newfunc, int layer) throw (mimasexception); 
00078     void setTransferFuncDerivLayer(ptrToFuncType newfunc, int layer) throw (mimasexception); 
00079     
00080     // user selectable transfer functions and the derivatives of the functions
00081     virtual double transferFunc(double val); 
00082     virtual double transferFuncDeriv(double val); 
00083     double transferFuncTanh(double val); 
00084     double transferFuncDerivTanh(double val); 
00085     double transferFuncLinear(double val); 
00086     double transferFuncDerivLinear(double val); 
00087 
00088     // debugging options
00089     void setDebug(bool val); 
00090   
00091 };
00092 
00093 }
00094 #endif
00095 
00096 

[GNU/Linux] [Qt] [Mesa] [STL] [Lapack] [Boost] [Magick++] [Xalan-C and Xerces-C] [doxygen] [graphviz] [FFTW] [popt] [xine] [Gnuplot] [gnu-arch] [gcc] [gstreamer] [autoconf/automake/make] [freshmeat.net] [opensource.org] [sourceforge.net] [MMVL]
mimas 2.1 - Copyright Mon Oct 30 11:31:17 2006, Bala Amavasai, Stuart Meikle, Arul Selvan, Fabio Caparrelli, Jan Wedekind, Manuel Boissenin, ...