00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef C_GRADIENTFUNCTION__H
00033 #define C_GRADIENTFUNCTION__H
00034
00035 #include "cparameters.h"
00036 #include "clearndataobject.h"
00037 #include <newmat/newmat.h>
00038
00039 class CFeatureList;
00040
00042
00049 class CAdaptiveEtaCalculator : virtual public CParameterObject
00050 {
00051 public:
00053 virtual void getWeightUpdates(CFeatureList *updates) = 0;
00054 };
00055
00057 class CIndividualEtaCalculator : public CAdaptiveEtaCalculator
00058 {
00059 protected:
00060 int numWeights;
00061 double *etas;
00062 public:
00063 CIndividualEtaCalculator(int numWeights, double *etas = NULL);
00064 virtual ~CIndividualEtaCalculator();
00065
00067 virtual void getWeightUpdates(CFeatureList *updates);
00068
00070 virtual void setEta(int index, double value);
00071 };
00072
00074
00083 class CVarioEta : public CAdaptiveEtaCalculator
00084 {
00085 protected:
00086 double *eta_i;
00087 double *v_i;
00088
00089
00090
00091
00092 unsigned int numParams;
00093 public:
00094 CVarioEta(unsigned int numParams, double eta, double beta = 0.01, double epsilon = 0.0001);
00095 ~CVarioEta();
00096
00097 virtual void getWeightUpdates(CFeatureList *updates);
00098 };
00099
00101
00115 class CGradientUpdateFunction : virtual public CParameterObject, virtual public CLearnDataObject
00116 {
00117 protected:
00118 CFeatureList *localGradientFeatureBuffer;
00119
00120
00121 CAdaptiveEtaCalculator *etaCalc;
00122 public:
00123 CGradientUpdateFunction();
00124 virtual ~CGradientUpdateFunction();
00125
00127
00131 void updateGradient(CFeatureList *gradientFeatures, double factor = 1.0);
00132
00134 virtual void updateWeights(CFeatureList *dParams) = 0;
00135
00137 virtual int getNumWeights() = 0;
00138
00140 virtual CAdaptiveEtaCalculator* getEtaCalculator();
00142 virtual void setEtaCalculator(CAdaptiveEtaCalculator *etaCalc);
00143
00145
00148 virtual void getWeights(double *parameters) = 0;
00150
00153 virtual void setWeights(double *parameters) = 0;
00154
00156 virtual void saveData(FILE *stream);
00158 virtual void loadData(FILE *stream);
00159
00161 virtual void resetData() = 0;
00162
00163 virtual void copy(CLearnDataObject *gradientFuntion);
00164 };
00165
00166
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00186
00187
00188
00189
00190
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00219
00228 class CGradientFunction : public CGradientUpdateFunction
00229 {
00230 protected:
00231 int num_inputs;
00232 int num_outputs;
00233
00234 ColumnVector *input_mean;
00235 ColumnVector *input_std;
00236
00237 ColumnVector *output_mean;
00238 ColumnVector *output_std;
00239
00240
00241 virtual void preprocessInput(ColumnVector *input, ColumnVector *norm_input);
00242 virtual void postprocessOutput(Matrix *norm_output, Matrix *output);
00243 public:
00244 CGradientFunction(int n_input, int n_output);
00245 virtual ~CGradientFunction();
00246
00247 virtual void getGradient(ColumnVector *input, ColumnVector *outputErrors, CFeatureList *gradientFeatures);
00249 virtual void getFunctionValue(ColumnVector *input, ColumnVector *output);
00250
00252 virtual void getInputDerivation(ColumnVector *input, Matrix *targetVector);
00253
00254
00255
00257 virtual void getGradientPre(ColumnVector *input, ColumnVector *outputErrors, CFeatureList *gradientFeatures) = 0;
00259 virtual void getFunctionValuePre(ColumnVector *input, ColumnVector *output) = 0;
00260
00262 virtual void getInputDerivationPre(ColumnVector *, Matrix *) {};
00263
00264
00266 virtual int getNumInputs();
00268 virtual int getNumOutputs();
00269
00270 void setInputMean(ColumnVector *input_mean);
00271 void setOutputMean(ColumnVector *output_mean);
00272
00273 void setInputStd(ColumnVector *input_std);
00274 void setOutputStd(ColumnVector *output_std);
00275 };
00276
00277
00278 #endif
00279