00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef CANALYZER_H
00033 #define CANALYZER_H
00034
// Integer codes naming three error measures.
// NOTE(review): presumably these are the accepted values of the `errorFunction`
// parameter of CFunctionComperator (mean squared error, mean absolute error,
// maximum error) - confirm against the implementation file.
00035 #define ANALYZER_MSE 1
00036 #define ANALYZER_MAE 2
00037 #define ANALYZER_MAXERROR 3
00038
00039 #include "cbatchlearning.h"
00040 #include "ctestsuit.h"
00041 #include "crewardfunction.h"
00042 #include "cagentcontroller.h"
00043 #include "cbaseobjects.h"
00044
00045 #include <list>
00046
00047 class CAbstractQFunction;
00048 class CAbstractVFunction;
00049
00051
// Helper class for dumping values of a state-value function (CAbstractVFunction).
// Only declarations are visible in this header; the behavior notes below are
// inferred from the signatures and should be confirmed against the implementation.
00063 class CVFunctionAnalyzer
00064 {
00065 protected:
// Value function under analysis (see setVFunction()).
00066 CAbstractVFunction *vFunction;
// Properties of the model state, as passed to the constructor (`modelState`).
00067 CStateProperties *modelStateProperties;
// NOTE(review): presumably a scratch state collection built from the model state and
// the modifier list given to the constructor - confirm; ownership is not visible here.
00068 CStateCollectionImpl *stateCollection;
00069 public:
00071
// @param vFunction  value function to analyze
// @param modelState properties of the model state
// @param modifiers  list of state modifiers (presumably registered with the internal
//                   state collection - TODO confirm)
00074 CVFunctionAnalyzer(CAbstractVFunction *vFunction, CStateProperties *modelState, std::list<CStateModifier *> *modifiers);
00075 virtual ~CVFunctionAnalyzer();
00076
00078
// Stream variant: presumably writes the function values obtained by varying state
// dimension `dim1` of `initstate` over `part1` sample points - TODO confirm.
00079 void save1DValues(FILE *stream, CState *initstate, int dim1, int part1);
00080
00082
// Two-dimensional counterpart of save1DValues(): takes two dimensions (`dim1`, `dim2`)
// with their own partition counts (`part1`, `part2`).
00083 void save2DValues(FILE *stream, CState *initstate, int dim1, int part1, int dim2, int part2);
00084
// Presumably writes the function value for every state in `states` - TODO confirm.
00086 void saveStateValues(FILE *stream, CStateList *states);
00087
00089
// File-name overloads of the three functions above.
// NOTE(review): presumably these open `filename` and delegate to the FILE* variants;
// the open mode and error handling are not visible here.
00090 void save1DValues(char *filename, CState *initstate, int dim1, int part1);
00091
00093
00094 void save2DValues(char *filename, CState *initstate, int dim1, int part1, int dim2, int part2);
00095
00097 void saveStateValues(char *filename, CStateList *states);
00098
00099
// Setter for the analyzed value function (takes the new function pointer).
00100 void setVFunction(CAbstractVFunction *vFunction);
00101 };
00102
00103
00105
// Helper class for dumping values of an action-value function (CAbstractQFunction).
// Mirrors CVFunctionAnalyzer, with an additional CActionSet argument on every save
// function. Only declarations are visible here; behavior notes are inferred from the
// signatures and should be confirmed against the implementation.
00117 class CQFunctionAnalyzer
00118 {
00119 protected:
// Q-function under analysis (see setQFunction()).
00120 CAbstractQFunction *qFunction;
// Properties of the model state, as passed to the constructor (`modelState`).
00121 CStateProperties *modelStateProperties;
// NOTE(review): presumably a scratch state collection built from the model state and
// the modifier list - confirm; ownership is not visible here.
00122 CStateCollectionImpl *stateCollection;
00123 public:
00125
// @param qFunction  Q-function to analyze
// @param modelState properties of the model state
// @param modifiers  list of state modifiers (use not visible here - TODO confirm)
00128 CQFunctionAnalyzer(CAbstractQFunction *qFunction, CStateProperties *modelState, std::list<CStateModifier *> *modifiers);
00129 virtual ~CQFunctionAnalyzer();
00130
// Setter for the analyzed Q-function.
00131 void setQFunction(CAbstractQFunction *l_qFunction);
00132
00134
// Presumably writes Q-values for the actions in `action` while varying dimension
// `dim1` of `initstate` over `part1` sample points - TODO confirm.
// Note: the parameter is named `action` but its type is a CActionSet.
00136 void save1DValues(FILE *stream, CActionSet *action, CState *initstate, int dim1, int part1);
00137
00139
// Two-dimensional counterpart of save1DValues().
00140 void save2DValues(FILE *stream, CActionSet *action, CState *initstate, int dim1, int part1, int dim2, int part2);
00141
// File-name overload; save2DValues is the only function of this class that has one.
00142 void save2DValues(char *filename, CActionSet *action, CState *initstate, int dim1, int part1, int dim2, int part2);
00143
00145
// Presumably writes Q-values for every state in `states` - TODO confirm.
00146 void saveStateValues(FILE *stream, CActionSet *action, CStateList *states);
00147 };
00148
00150
// Abstract base class for comparing two functions on a set of states. It derives from
// CRewardFunction and declares getReward(), so an instance can be passed wherever a
// CRewardFunction is expected. Subclasses supply the two functions by implementing
// the pure virtual getValue(). Only declarations are visible in this header.
00161 class CFunctionComperator : public CRewardFunction
00162 {
00163 protected:
// Properties of the model state, as passed to the constructor (`modelState`).
00164 CStateProperties *modelStateProperties;
// NOTE(review): presumably a scratch state collection used while sampling states -
// confirm; ownership is not visible here.
00165 CStateCollectionImpl *stateCollection;
00166
// Pure virtual: value of function number `numFunc` at `state`.
// NOTE(review): the numbering convention for `numFunc` (0/1 vs. 1/2) is not visible here.
00168 virtual double getValue(int numFunc, CStateCollection *state) = 0;
// Presumably the error between the two functions at `state`, measured according to
// `errorFunction` - TODO confirm.
00170 virtual double getDifference(CStateCollection *state, int errorFunction);
00171
// Presumably fills `state` with a randomly drawn state - TODO confirm the distribution.
00173 void getRandomState(CState *state);
00174 public:
00176
// @param modelState properties of the model state
// @param modifiers  list of state modifiers (use not visible here - TODO confirm)
00179 CFunctionComperator(CStateProperties *modelState, std::list<CStateModifier *> *modifiers);
00180 virtual ~CFunctionComperator();
00181
00183
// Compares the two functions on `nSamples` states (presumably random ones, going by
// the name). The default `errorFunction = 1` equals the value of ANALYZER_MSE defined
// at the top of this header.
00186 double compareFunctionsRandom(int nSamples, int errorFunction = 1);
00187
00189
// Compares the two functions on the states in `states`; same default error code (1).
00192 double compareFunctionsStates(CStateList *states, int errorFunction = 1);
00193
// NOTE(review): presumably implements the CRewardFunction interface; which of the
// arguments the result depends on is not visible here.
00194 double getReward(CStateCollection *oldState, CAction *action, CStateCollection *newState);
00195 };
00196
00198
// CFunctionComperator specialisation holding two state-value functions to compare.
00201 class CVFunctionComperator : public CFunctionComperator
00202 {
00203 protected:
// The two value functions under comparison, as passed to the constructor.
00204 CAbstractVFunction *vFunction1;
00205 CAbstractVFunction *vFunction2;
00206
// Implements the base class' pure virtual getValue().
// NOTE(review): presumably `numFunc` selects between vFunction1 and vFunction2 - confirm.
00207 virtual double getValue(int numFunc, CStateCollection *state);
00208
00209 public:
// @param modelState, modifiers  same parameter types as the CFunctionComperator constructor
// @param vFunction1, vFunction2 the two value functions to compare
00210 CVFunctionComperator(CStateProperties *modelState, std::list<CStateModifier *> *modifiers, CAbstractVFunction *vFunction1, CAbstractVFunction *vFunction2);
// Empty inline destructor: no cleanup is performed at this level.
00211 virtual ~CVFunctionComperator(){};
00212 };
00213
00215
// CFunctionComperator specialisation holding two Q-functions and one action.
00219 class CQFunctionComperator : public CFunctionComperator
00220 {
00221 protected:
// The two Q-functions under comparison, as passed to the constructor.
00222 CAbstractQFunction *qFunction1;
00223 CAbstractQFunction *qFunction2;
// NOTE(review): presumably the action at which both Q-functions are evaluated - confirm.
00224 CAction *action;
00225
// Implements the base class' pure virtual getValue().
// NOTE(review): presumably `numFunc` selects between qFunction1 and qFunction2 - confirm.
00226 virtual double getValue(int numFunc, CStateCollection *state);
00227
00228 public:
// @param modelState, modifiers  same parameter types as the CFunctionComperator constructor
// @param qFunction1, qFunction2 the two Q-functions to compare
// @param action                 action associated with the comparison
00229 CQFunctionComperator(CStateProperties *modelState, std::list<CStateModifier *> *modifiers, CAbstractQFunction *qFunction1, CAbstractQFunction *qFunction2, CAction *action);
00230
// Empty inline destructor: no cleanup is performed at this level.
00231 virtual ~CQFunctionComperator(){};
00232 };
00233
00235
// Pairs a list of states with an agent controller and declares a function for saving
// actions to a stream. Derives from CActionObject; the constructor also takes a
// CActionSet. Only declarations are visible in this header.
00238 class CControllerAnalyzer : public CActionObject
00239 {
00240 protected:
// State list associated with this analyzer (see getStateList()/setStateList()).
00241 CStateList *states;
// Controller associated with this analyzer (see getController()/setController()).
00242 CAgentController *controller;
00243
00244 public:
// @param states     state list to analyze
// @param controller controller to analyze
// @param actions    action set (presumably forwarded to the CActionObject base - TODO confirm)
00245 CControllerAnalyzer(CStateList *states, CAgentController *controller, CActionSet *actions);
00246 virtual ~CControllerAnalyzer();
00247
// Accessors for the state list.
00248 CStateList *getStateList();
00249 void setStateList(CStateList *states);
00250
// Accessors for the controller.
00251 CAgentController *getController();
00252 void setController(CAgentController *Controller);
00253
// NOTE(review): presumably writes the action the controller selects for each state in
// the state list to `stream`; how `modifiers` is used is not visible here - confirm.
00254 void saveActions(FILE *stream, std::list<CStateModifier *> *modifiers);
00255 };
00256
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
// Analyzer variant of fitted Q-iteration: derives from both CFittedQIteration and
// CTestSuiteEvaluatorLogger. Only declarations are visible in this header; behavior
// notes below are inferred from names and should be confirmed against the implementation.
// NOTE(review): `string`, `FILE` and `NULL` are used below, but no visible include
// provides them directly (only project headers and <list>); they are presumably
// supplied, along with an unqualified `string`, by the included project headers.
00281 class CFittedQIterationAnalyzer : public CFittedQIteration, public CTestSuiteEvaluatorLogger
00282 {
00283 protected:
// Evaluator passed to the constructor.
00284 CPolicySameStateEvaluator *evaluator;
00285
// NOTE(review): presumably a running count of evaluations performed - confirm.
00286 int numEvaluations;
// NOTE(review): presumably the output file for analysis data; where it is opened and
// closed is not visible here.
00287 FILE *analyzerFile;
00288
// NOTE(review): presumably a scratch state buffer - confirm.
00289 CState *buffState2;
// NOTE(review): meaning of this flag is not visible here - confirm.
00290 bool useQValues;
00291
// NOTE(review): presumably cached results from the latest getValue()/addResidualInput()
// call - confirm.
00292 double lastQValue;
00293 double lastEstimatedQValue;
00294
// NOTE(review): presumably overrides a CFittedQIteration hook - confirm in the base class.
00295 virtual double getValue(CStateCollection *stateCollection, CAction *action);
00296 public:
// Takes the learning components (Q-function, estimation policy, episode and reward
// histories, supervised learner, residual properties) plus the evaluator that is
// specific to this class.
// NOTE(review): presumably all but `evaluator` are forwarded to CFittedQIteration - confirm.
00297 CFittedQIterationAnalyzer(CQFunction *qFunction, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedQFunctionLearner *learner, CStateProperties *residualProperties, CPolicySameStateEvaluator *evaluator);
00298
00299 virtual ~CFittedQIterationAnalyzer();
00300
// `nextHistoryAction` defaults to NULL and `nextReward` to 0.0.
// NOTE(review): presumably overrides the base-class function of the same name - confirm.
00301 virtual void addResidualInput(CStep *step, CAction *action, double V, double newV, double nearestNeighborDistance, CAction *nextHistoryAction = NULL, double nextReward = 0.0);
00302
// NOTE(review): presumably part of the CTestSuiteEvaluatorLogger interface - confirm.
00303 virtual void evaluate(string evaluationDirectory, int trial, int numEpisodes);
00304
// NOTE(review): presumably part of the CTestSuiteEvaluatorLogger interface - confirm.
00305 virtual void startNewEvaluation(string evaluationDirectory, CParameters *parameters, int trial);
00306
00307 };
00308
00309
00310 #endif // CANALYZER_H
00311
00312