00001
00002
00003
00004
00005
00006 #ifndef NN_MLP_TOOL_H
00007 #define NN_MLP_TOOL_H
00008
#include <cmath>
#include <stdarg.h>

#include <boost/numeric/ublas/matrix.hpp>
#include <boost/numeric/ublas/vector.hpp>
#include <boost/shared_array.hpp>
#include <boost/shared_ptr.hpp>

#include "object.h"
#include "mimasexception.h"
00016
00017 namespace mimas {
00018
00019
00028 class nn_mlp_tool : public object
00029 {
00030 public:
00031 typedef boost::numeric::ublas::matrix< double > Matrix;
00032 typedef boost::numeric::ublas::vector< double > Vector;
00033 private:
00034 typedef double (nn_mlp_tool::*ptrToFuncType)(double);
00035 boost::shared_array< ptrToFuncType > ptrToTransferFunc;
00036 boost::shared_array< ptrToFuncType > ptrToTransferFuncDeriv;
00037 boost::shared_array< Matrix > weight;
00038 boost::shared_array< Vector > bias;
00039 boost::shared_array< Vector > y;
00040 boost::shared_array< Vector > net;
00041 boost::shared_array< Vector > delta;
00042 boost::shared_array< Matrix > prevWeightChange;
00043 boost::shared_array< Vector > prevBiasChange;
00044 boost::shared_array< int > numnodes;
00045 bool network_initialised;
00046 bool debugtrain;
00047 int numlayers;
00048
00049 int iterations;
00050 void init(void);
00051 double deltaRule( Matrix& pat, int p) throw (mimasexception);
00052
00053 public:
00054 nn_mlp_tool(void);
00055 ~nn_mlp_tool(void);
00056 void init(int nodes, ...) throw (mimasexception);
00057 void clear(void);
00058 void showConfig(void);
00059 Matrix feedForward( Matrix& pat) throw (mimasexception);
00060
00061
00062 void save (const char *fn) throw (mimasexception);
00063 void load (const char *fn) throw (mimasexception);
00064
00065
00066
00067 void setIterations(int i) throw (mimasexception);
00068 int getIterations(void);
00069
00070
00071 Matrix trainOnline( Matrix& pat, double learningRate=0.3, double momentum=0.05, double weightDecay=0.0) throw (mimasexception);
00072 Matrix trainBatch(Matrix& pat, double learningRate=0.3, double momentum=0.05, double weightDecay=0.0) throw (mimasexception);
00073 Matrix trainQprop(Matrix& pat, double learningRate=0.0005, double momentum=0.05, double maxFactor=2.0) throw (mimasexception);
00074 Matrix trainRprop(Matrix& pat, double etapos=1.2, double etaneg=0.5) throw (mimasexception);
00075
00076
00077 void setTransferFuncLayer(ptrToFuncType newfunc, int layer) throw (mimasexception);
00078 void setTransferFuncDerivLayer(ptrToFuncType newfunc, int layer) throw (mimasexception);
00079
00080
00081 virtual double transferFunc(double val);
00082 virtual double transferFuncDeriv(double val);
00083 double transferFuncTanh(double val);
00084 double transferFuncDerivTanh(double val);
00085 double transferFuncLinear(double val);
00086 double transferFuncDerivLinear(double val);
00087
00088
00089 void setDebug(bool val);
00090
00091 };
00092
00093 }
00094 #endif
00095
00096