00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef CVTORCHFUNCTION_H
00033 #define CVTORCHFUNCTION_H
00034
00035 #include "ConnectedMachine.h"
00036 #include "MLP.h"
00037 #include "cvfunction.h"
00038 #include "ccontinuousactions.h"
00039
00040
00041 using Torch::Sequence;
00042 using Torch::Parameters;
00043 using Torch::Machine;
00044 using Torch::GradientMachine;
00045
00047
00056 class CTorchFunction
00057 {
00058 protected:
00060 Machine *machine;
00061
00063 Sequence *input;
00064
00065 public:
00067 CTorchFunction(Machine *machine);
00068 virtual ~CTorchFunction();
00069
00070
00071
00073 virtual Machine *getMachine();
00074
00075 virtual double getValueFromMachine(Sequence *state);
00076 };
00077
00079
00085 class CTorchGradientFunction : public CTorchFunction, public CGradientFunction
00086 {
00087 protected:
00088 Sequence *alpha;
00090 GradientMachine *gradientMachine;
00091
00092 CAdaptiveEtaCalculator *localEtaCalc;
00093 public:
00095 CTorchGradientFunction(int numInputs, int numOutputs);
00096 CTorchGradientFunction(GradientMachine *machine);
00097 virtual ~CTorchGradientFunction();
00098
00100 virtual void resetData();
00101
00102 virtual void updateWeights(CFeatureList *gradientFeatures);
00103
00104 virtual int getNumWeights();
00105
00106 virtual void getInputDerivationPre(ColumnVector *input, Matrix *targetVector);
00107 virtual void getFunctionValuePre(ColumnVector *input, ColumnVector *output);
00108
00109
00110 virtual void getWeights(double *parameters);
00111 virtual void setWeights(double *parameters);
00112
00113 virtual void getGradientPre(ColumnVector *input, ColumnVector *outputErrors, CFeatureList *gradientFeatures);
00114
00115 void setGradientMachine(GradientMachine *gradientMachine);
00116 GradientMachine *getGradientMachine();
00117 };
00118
00119 class CTorchGradientEtaCalculator : public CIndividualEtaCalculator
00120 {
00121 public:
00122 CTorchGradientEtaCalculator(GradientMachine *gradientMachine);
00123 };
00124
00125 class CTorchVFunction : public CAbstractVFunction
00126 {
00127 protected:
00129
00131 void getInputSequence(CState *state, Sequence *input);
00132
00133 CTorchFunction *torchFunction;
00134
00135 Sequence *input;
00136
00137 public:
00139 CTorchVFunction(CTorchFunction *torchFunction, CStateProperties *properties);
00140 virtual ~CTorchVFunction();
00141
00143 virtual double getValue(CState *state);
00144
00145 };
00146
00147
00149
00155 class CVFunctionFromGradientFunction : public CGradientVFunction, public CVFunctionInputDerivationCalculator
00156 {
00157 protected:
00159 CGradientFunction *gradientFunction;
00160
00161 ColumnVector *input;
00162 ColumnVector *outputError;
00163 Matrix *inputDerivation;
00164
00165 virtual void updateWeights(CFeatureList *gradientFeatures);
00166 void getInputSequence(CState *state, ColumnVector *sequence);
00167
00168 public:
00170 CVFunctionFromGradientFunction(CGradientFunction *gradientFunction, CStateProperties *properties);
00171 virtual ~CVFunctionFromGradientFunction();
00172
00174
00175 virtual void setValue(CState *state, double value);
00177
00178
00180 virtual void resetData();
00181
00183 virtual double getValue(CState *state);
00184
00185
00186
00187 virtual void getGradient(CStateCollection *originalState, CFeatureList *modifiedState);
00188
00189 virtual int getNumWeights();
00190
00191 virtual CAbstractVETraces *getStandardETraces();
00192
00193 void getInputDerivation(CStateCollection *originalState, ColumnVector *targetVector);
00194
00195 virtual void getWeights(double *parameters);
00196 virtual void setWeights(double *parameters);
00197
00198 };
00199
00200 class CQFunctionFromGradientFunction : public CContinuousActionQFunction, CStateObject
00201 {
00202 protected:
00203 CGradientFunction *gradientFunction;
00204 ColumnVector *input;
00205 ColumnVector *outputError;
00206
00207 CActionSet *staticActions;
00208
00209 void getInputSequence(ColumnVector *input, CState *state, CContinuousActionData *data);
00210 virtual void updateWeights(CFeatureList *gradientFeatures);
00211
00212 public:
00213 CQFunctionFromGradientFunction(CContinuousAction *contAction, CGradientFunction *torchGradientFunction, CActionSet *actions, CStateProperties *properties);
00214 virtual ~CQFunctionFromGradientFunction();
00215
00216 virtual void getBestContinuousAction(CStateCollection *state, CContinuousActionData *actionData);
00217
00218 virtual void updateCAValue(CStateCollection *state, CContinuousActionData *data, double td);
00219 virtual void setCAValue(CStateCollection *state, CContinuousActionData *data, double qValue);
00220 virtual double getCAValue(CStateCollection *state, CContinuousActionData *data);
00221
00222
00223 virtual void getCAGradient(CStateCollection *state, CContinuousActionData *data, CFeatureList *gradient);
00224 virtual int getNumWeights();
00225
00226 virtual void getWeights(double *parameters);
00227 virtual void setWeights(double *parameters);
00228
00229 virtual void resetData();
00230 };
00231
00232
00233
00234 #endif